<img src="./vit.gif" width="500px"></img>
## Vision Transformer - Pytorch
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>.
The official Jax repository is <a href="https://github.com/google-research/vision_transformer">here</a>.
## Install
```bash
$ pip install vit-pytorch
```
## Usage
```python
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to attend to

preds = v(img, mask = mask) # (1, 1000)
```
## Parameters
- `image_size`: int.
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height.
- `patch_size`: int.
Size of each square patch, in pixels. `image_size` must be divisible by `patch_size`.
The number of patches is `n = (image_size // patch_size) ** 2`, and `n` **must be greater than 16**.
- `num_classes`: int.
Number of classes to classify.
- `dim`: int.
Last dimension of output tensor after linear transformation `nn.Linear(..., dim)`.
- `depth`: int.
Number of Transformer blocks.
- `heads`: int.
Number of heads in Multi-head Attention layer.
- `mlp_dim`: int.
Dimension of the MLP (FeedForward) layer.
- `channels`: int, default `3`.
Number of image's channels.
- `dropout`: float between `[0, 1]`, default `0.`.
Dropout rate.
- `emb_dropout`: float between `[0, 1]`, default `0`.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling
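
To make the `image_size` / `patch_size` arithmetic concrete, here is a small sketch in plain Python. The helper name `patch_grid` is hypothetical and not part of `vit-pytorch`; it only reproduces the formula given above, plus the flattened size of one patch under the assumption that each patch is flattened across all channels before the linear projection.

```python
def patch_grid(image_size, patch_size, channels=3):
    """Return (patches_per_side, num_patches, patch_dim) for a square image.

    Hypothetical helper for illustration only; not part of vit-pytorch.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    num_patches = per_side ** 2
    # assumption: a patch is flattened to channels * patch_size * patch_size values
    patch_dim = channels * patch_size ** 2
    return per_side, num_patches, patch_dim

print(patch_grid(256, 32))  # (8, 64, 3072)
```

With the settings from the Usage section this gives an 8 x 8 grid of patches, which is why the optional `mask` there has shape `(1, 8, 8)`.
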
## Distillation
<img src="./distill.png" width="300px"></img>
A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that use of a distillation token for distilling knowledge from convolutional nets to vision transformer can yield small and efficient vision transformers. This repository offers the means to do distillation easily.
ex. distilling from Resnet50 (or any teacher) to a vision transformer
```python
import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,           # temperature of distillation
    alpha = 0.5                # trade between main loss and distillation loss
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after lots of training above ...

pred = v(img) # (2, 1000)
```
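
For intuition about what `temperature` and `alpha` control, below is a minimal sketch of a standard soft-distillation objective in plain Python. It is an illustration under assumptions, not necessarily what `DistillWrapper` computes internally: the function names are hypothetical, and a single set of student logits is used for both terms, whereas the paper reads the distillation term from a separate distillation token.

```python
import math

def softmax(logits, temperature=1.0):
    # numerically stable softmax over temperature-scaled logits
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_distillation_loss(student_logits, teacher_logits, label,
                           temperature=3.0, alpha=0.5):
    # main loss: cross entropy of the student against the true label
    ce = -math.log(softmax(student_logits)[label])
    # distillation loss: KL(teacher || student) on softened distributions
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(t * math.log(t / s) for t, s in zip(p_teacher, p_student) if t > 0)
    # customary temperature-squared factor to keep the two terms on a similar scale
    kl *= temperature ** 2
    # alpha trades the main loss against the distillation loss
    return (1 - alpha) * ce + alpha * kl

loss = soft_distillation_loss([2.0, 0.5, -1.0], [1.5, 1.0, -0.5], label=0)
```

A higher temperature flattens both distributions, so the student is also pushed to match the teacher's relative ranking of the wrong classes; `alpha = 0` reduces to ordinary supervised training.
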
The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back to `ViT` after you have completed distillation training.
You can also use the handy `.to_vit` method on the `DistillableViT` instance to get back a `ViT` instance.
```python
v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```
## Deep ViT
This <a href="https://arxiv.org/abs/2103.11886">paper</a> notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a> paper from NLP.
You can use it as follows
```python
import torch
from vit_pytorch.deepvit import DeepViT

v = DeepViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
```
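
The head-mixing idea can be sketched without any framework. The snippet below is a toy illustration in plain Python, not the `DeepViT` implementation: `mix_heads` is a hypothetical name, the mixing matrix `theta` would be learned in practice, and the normalization that the paper applies after mixing is left out.

```python
def mix_heads(attn, theta):
    """Mix post-softmax attention maps across heads.

    attn:  list of H attention maps, each an n x n list of lists
    theta: H x H mixing matrix (learned in practice, fixed here)
    Returns H new maps with mixed[g] = sum_h theta[h][g] * attn[h].
    """
    heads = len(attn)
    n = len(attn[0])
    return [
        [
            [sum(theta[h][g] * attn[h][i][j] for h in range(heads)) for j in range(n)]
            for i in range(n)
        ]
        for g in range(heads)
    ]

# two heads over two tokens: one sharply peaked, one uniform
attn = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[0.5, 0.5], [0.5, 0.5]],
]
theta = [[0.5, 0.5], [0.5, 0.5]]  # every output head averages both input heads
print(mix_heads(attn, theta)[0])  # [[0.75, 0.25], [0.25, 0.75]]
```

With an identity `theta` the maps pass through unchanged, so ordinary multi-head attention is the special case where no mixing happens.
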
## Token-to-Token ViT
<img src="./t2t.png" width="400px"></img>
<a href="https://arxiv.org/abs/2101.11986">This paper</a> proposes that the first couple layers should downsample the image sequence by unfolding, leading to overlapping image data in each token as shown in the figure above. You can use this variant of the `ViT` as follows.
```python
import torch
from vit_pytorch.t2t import T2TViT

v = T2TViT(
    dim = 512,
    image_size = 224,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    num_classes = 1000,
    t2t_layers = ((7, 4), (3, 2), (3, 2)) # tuples of the kernel size and stride of each consecutive layers of the initial token to token module
)

img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)
```
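
To see how much each `(kernel_size, stride)` pair shrinks the token grid, the sketch below applies the usual sliding-window output-size formula in plain Python. The helper `t2t_grid_sizes` is hypothetical, and the padding of `stride // 2` is an assumption made for this illustration; check the `T2TViT` source for the exact padding it uses.

```python
def t2t_grid_sizes(image_size, t2t_layers):
    """Side length of the token grid after each unfolding layer."""
    sizes = []
    size = image_size
    for kernel_size, stride in t2t_layers:
        padding = stride // 2  # assumption for this sketch
        # standard sliding-window output size
        size = (size + 2 * padding - kernel_size) // stride + 1
        sizes.append(size)
    return sizes

print(t2t_grid_sizes(224, ((7, 4), (3, 2), (3, 2))))  # [56, 28, 14]
```

Under these assumptions the 224 x 224 image ends up as a 14 x 14 grid of tokens. Because every kernel is larger than its stride, neighbouring windows overlap, which is the overlapping image data the paper refers to.
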
## Masked Patch Prediction
Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
```python
import torch
from vit_pytorch import ViT
from vit_pytorch.mpp import MPP

model = ViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1
)

mpp_trainer = MPP(
    transformer=model,
    patch_size=32,
    dim=1024,
    mask_prob=0.15,          # probability of using token in masked prediction task
    random_patch_prob=0.30,  # probability of randomly replacing a token being used for mpp
    replace_prob=0.50,       # probability of replacing a token being used for mpp with the mask token
)

opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = mpp_trainer(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')
```
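
The three probabilities are easier to reason about with a toy sampler. The sketch below is one plausible, BERT-style reading of them in plain Python; it is an assumption for illustration and not the `MPP` implementation, and the function name and action labels are hypothetical.

```python
import random

def plan_mpp_corruption(num_patches, mask_prob=0.15, replace_prob=0.50,
                        random_patch_prob=0.30, rng=None):
    """Decide, per patch, how it would be treated in masked patch prediction."""
    if rng is None:
        rng = random.Random()
    plan = []
    for _ in range(num_patches):
        if rng.random() >= mask_prob:
            plan.append("untouched")      # not part of the prediction task
            continue
        r = rng.random()
        if r < replace_prob:
            plan.append("mask_token")     # input replaced by the mask token
        elif r < replace_prob + random_patch_prob:
            plan.append("random_patch")   # input replaced by another patch
        else:
            plan.append("keep_original")  # predicted, but input left as is
    return plan

plan = plan_mpp_corruption(64, rng=random.Random(0))
print({action: plan.count(action) for action in sorted(set(plan))})
```

Under this reading, roughly 15% of the 64 patches are selected for prediction, and of those about half see the mask token, about 30% see a random patch, and the rest keep their original content.
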
## Research Ideas
### Self Supervised Training
You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
(1)
```bash
$ pip install byol-pytorch
```
(2)
```python
import torch
from vit_pytorch import ViT
from byol_pytorch import BYOL

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

learner = BYOL(
    model,
    image_size = 256,
    hidden_layer = 'to_latent'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')
```
A pytorch-lightning script is ready for you to use at the repository link above.
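
The moving average step in the loop above is what keeps the target encoder a slowly changing copy of the online encoder. A minimal sketch of such an exponential moving average in plain Python follows; `ema_update` is a hypothetical name, weights are shown as flat lists of floats, and the decay values are illustrative rather than the `byol-pytorch` defaults.

```python
def ema_update(target, online, beta=0.99):
    """Move each target weight a small step toward the matching online weight."""
    return [beta * t + (1 - beta) * o for t, o in zip(target, online)]

target = [0.0, 1.0]
online = [1.0, 1.0]
for _ in range(3):
    target = ema_update(target, online, beta=0.5)
print(target)  # [0.875, 1.0]
```

A `beta` close to 1 makes the target lag far behind the online weights, while `beta = 0` would simply copy them on every step.
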
### Efficient Attention
There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plug in your own sparse attention transformer.
An example with <a href="https://arxiv.org/abs/2102.03902">Nystromformer</a>
```bash
$ pip install nystrom-attention
```
```python
import torch
from vit_pytorch.efficient import ViT
from nystrom_attention import Nystromformer

efficient_transformer = Nystromformer(
    dim = 512,
    depth = 12,
    heads = 8,
    num_landmarks = 256
)

v = ViT(
    dim = 512,
    image_size = 2048,
    patch_size = 32,
    num_classes = 1000,
    transformer = efficient_transformer
)

img = torch.randn(1, 3, 2048, 2048) # your high resolution picture
v(img) # (1, 1000)
```
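
A rough count shows why the quadratic cost matters at this resolution. The sketch below is back-of-the-envelope arithmetic in plain Python, not a benchmark: `attention_entries` is a hypothetical helper, the class token and constant factors are ignored, and the landmark count assumes the approximation is built from an n x m, an m x m and an m x n matrix.

```python
def attention_entries(image_size, patch_size, num_landmarks):
    n = (image_size // patch_size) ** 2  # sequence length (patches only)
    m = num_landmarks
    full = n * n                         # full n x n attention map, per head
    landmark = 2 * n * m + m * m         # n x m, m x m and m x n factors
    return n, full, landmark

print(attention_entries(2048, 32, 256))  # (4096, 16777216, 2162688)
```

For the 2048 x 2048 image above, the full attention map has about 16.8 million entries per head per layer, while the three landmark factors together hold about 2.2 million.
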
Other sparse attention frameworks I would highly recommend are <a href="https://github.com/lucidrains/routing-transformer">Routing Transformer</a> and <a href="https://github.com/lucidrains/sinkhorn-transformer">Sinkhorn Transformer</a>.
### Combining with other Transformer improvements
This paper purposely used the most vanilla of attention networks to make a statement. If you would like to use some of the latest improvements for attention nets, please use the `Encoder` from <a href="https://github.com/lucidrains/x-transformers">this repository</a>.
ex.
```bash
$ pip install x-transformers
```
```python
import torch
from vit_pytorch.efficient import ViT
from x_transformers import Encoder

v = ViT(
    dim = 512,
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    transformer = Encoder(
        dim = 512,                  # set to be the same as the wrapper
        depth = 12,
        heads = 8,
        ff_glu = True,              # ex. feed forward GLU variant https://arxiv.org/abs/2002.05202
        residual_attn = True        # ex. residual attention https://arxiv.org/abs/2012.11747
    )
)

img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)
```
## Resources
Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
1. <a href="http://jalammar.github.io/illustrated-transformer/">Illustrated Transformer</a> - Jay Alammar
2. <a href="http://peterbloem.nl/blog/transformers">Transformers from Scratch</a> - Peter Bloem
3. <a href="https://nlp.seas.harvard.edu/2018/04/03/attention.html">The Annotated Transformer</a> - Harvard NLP
## Citations
```bibtex
@misc{dosovitskiy2020image,
title = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
author = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
year = {2020},
eprint = {2010.11929},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{touvron2020training,
title = {Training data-efficient image transformers & distillation through attention},
author = {Hugo Touvron and Matthieu Cord and Matthijs Douze and Francisco Massa and Alexandre Sablayrolles and Hervé Jégou},
year = {2020},
eprint = {2012.12877},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{yuan2021tokenstotoken,
title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
year = {2021},
eprint = {2101.11986},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{zhou2021deepvit,
title = {DeepViT: Towards Deeper Vision Transformer},
author = {Daquan Zhou and Bingyi Kang and Xiaojie Jin and Linjie Yang and Xiaochen Lian and Qibin Hou and Jiashi Feng},
year = {2021},
eprint = {2103.11886},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
author = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
year = {2017},
eprint = {1706.03762},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
```
*I visualise a time when we will be to robots what dogs are to humans, and I'm rooting for the machines.* — Claude Shannon