<img src="./vit.png" width="500px"></img>
## Vision Transformer - Pytorch
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in image classification with only a single transformer encoder, in PyTorch. Its significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but it may as well be laid out for everyone so we can expedite the attention revolution.
## Install
```bash
$ pip install vit-pytorch
```
## Usage
```python
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

img = torch.randn(1, 3, 256, 256)
mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to attend to
preds = v(img, mask = mask) # (1, 1000)
```
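
The `(1, 8, 8)` mask shape above follows from the patch arithmetic: a 256-pixel image cut into 32-pixel patches gives an 8 x 8 grid. The sketch below only illustrates that arithmetic; `patch_grid` is a hypothetical helper, not part of `vit_pytorch`, and it assumes square images with 3 channels whose side is divisible by the patch size.

```python
def patch_grid(image_size: int, patch_size: int, channels: int = 3):
    """Hypothetical helper: patch counts for a square image (not a vit_pytorch API)."""
    # Assumption: the image side must divide evenly into patches
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size       # patches along one side
    num_patches = per_side ** 2               # total patches in the grid
    patch_dim = channels * patch_size ** 2    # values in one flattened patch
    return per_side, num_patches, patch_dim

per_side, num_patches, patch_dim = patch_grid(256, 32)
print(per_side, num_patches, patch_dim)  # 8 64 3072
```

With the settings used in the example, the mask therefore carries one boolean per patch, 64 in total.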
## Suggestion
You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
(1)
```bash
$ pip install byol-pytorch
```
(2)
```python
import torch
from vit_pytorch import ViT
from byol_pytorch import BYOL

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

learner = BYOL(
    model,
    image_size = 256,
    hidden_layer = 'to_cls_token'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')
```
A pytorch-lightning script is ready for you to use at the repository link above.
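
For intuition about the `update_moving_average()` call in the training loop, BYOL keeps a target encoder whose weights are an exponential moving average of the online encoder's weights. The sketch below shows that update rule on plain Python lists; `ema_update` and the decay value `beta = 0.99` are illustrative assumptions, not the `byol_pytorch` implementation.

```python
def ema_update(target, online, beta=0.99):
    """Illustrative EMA step: target <- beta * target + (1 - beta) * online.

    `target` and `online` stand in for flattened parameter values;
    `beta` is an assumed decay rate, not read from byol_pytorch.
    """
    return [beta * t + (1.0 - beta) * o for t, o in zip(target, online)]

target = [1.0, 1.0]   # stand-in for target encoder weights
online = [0.0, 2.0]   # stand-in for online encoder weights

# The target drifts slowly toward the online weights over repeated steps
for _ in range(3):
    target = ema_update(target, online)
print(target)
```

A decay close to 1 makes the target change slowly, which is what gives the online network a stable regression target.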
## Citations
```bibtex
@inproceedings{anonymous2021an,
    title     = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
    author    = {Anonymous},
    booktitle = {Submitted to International Conference on Learning Representations},
    year      = {2021},
    url       = {https://openreview.net/forum?id=YicbFdNTTy},
    note      = {under review}
}
```