mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e8ca6038c9 | ||
|
|
1106a2ba88 | ||
|
|
f95fa59422 | ||
|
|
be1712ebe2 | ||
|
|
1a76944124 | ||
|
|
2263b7396f | ||
|
|
74074e2b6c | ||
|
|
0c68688d61 | ||
|
|
5918f301a2 | ||
|
|
4a6469eecc | ||
|
|
5a225c8e3f | ||
|
|
e0007bd801 | ||
|
|
db98ed7a8e | ||
|
|
dc4b3327ce | ||
|
|
aa8f0a7bf3 | ||
|
|
34e6284f95 | ||
|
|
aa9ed249a3 | ||
|
|
ea0924ec96 | ||
|
|
59787a6b7e | ||
|
|
24339644ca |
120
README.md
120
README.md
@@ -4,7 +4,9 @@
|
||||
|
||||
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
|
||||
|
||||
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>
|
||||
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>.
|
||||
|
||||
The official Jax repository is <a href="https://github.com/google-research/vision_transformer">here</a>.
|
||||
|
||||
## Install
|
||||
|
||||
@@ -58,6 +60,63 @@ Number of image's channels.
|
||||
Dropout rate.
|
||||
- `emb_dropout`: float between `[0, 1]`, default `0`.
|
||||
Embedding dropout rate.
|
||||
- `pool`: string, either `cls` token pooling or `mean` pooling
|
||||
|
||||
## Distillation
|
||||
|
||||
<img src="./distill.png" width="300px"></img>
|
||||
|
||||
A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that use of a distillation token for distilling knowledge from convolutional nets to vision transformer can yield small and efficient vision transformers. This repository offers the means to do distillation easily.
|
||||
|
||||
ex. distilling from Resnet50 (or any teacher) to a vision transformer
|
||||
|
||||
```python
|
||||
import torch
|
||||
from torchvision.models import resnet50
|
||||
|
||||
from vit_pytorch.distill import DistillableViT, DistillWrapper
|
||||
|
||||
teacher = resnet50(pretrained = True)
|
||||
|
||||
v = DistillableViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
distiller = DistillWrapper(
|
||||
student = v,
|
||||
teacher = teacher,
|
||||
temperature = 3, # temperature of distillation
|
||||
alpha = 0.5 # trade between main loss and distillation loss
|
||||
)
|
||||
|
||||
img = torch.randn(2, 3, 256, 256)
|
||||
labels = torch.randint(0, 1000, (2,))
|
||||
|
||||
loss = distiller(img, labels)
|
||||
loss.backward()
|
||||
|
||||
# after lots of training above ...
|
||||
|
||||
pred = v(img) # (2, 1000)
|
||||
```
|
||||
|
||||
The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back to `ViT` after you have completed distillation training.
|
||||
|
||||
You can also use the handy `.to_vit` method on the `DistillableViT` instance to get back a `ViT` instance.
|
||||
|
||||
```python
|
||||
v = v.to_vit()
|
||||
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
|
||||
```
|
||||
|
||||
## Research Ideas
|
||||
|
||||
### Self Supervised Training
|
||||
@@ -88,7 +147,7 @@ model = ViT(
|
||||
learner = BYOL(
|
||||
model,
|
||||
image_size = 256,
|
||||
hidden_layer = 'to_cls_token'
|
||||
hidden_layer = 'to_latent'
|
||||
)
|
||||
|
||||
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
|
||||
@@ -147,6 +206,50 @@ v(img) # (1, 1000)
|
||||
|
||||
Other sparse attention frameworks I would highly recommend is <a href="https://github.com/lucidrains/routing-transformer">Routing Transformer</a> or <a href="https://github.com/lucidrains/sinkhorn-transformer">Sinkhorn Transformer</a>
|
||||
|
||||
### Combining with other Transformer improvements
|
||||
|
||||
This paper purposely used the most vanilla of attention networks to make a statement. If you would like to use some of the latest improvements for attention nets, please use the `Encoder` from <a href="https://github.com/lucidrains/x-transformers">this repository</a>.
|
||||
|
||||
ex.
|
||||
|
||||
```bash
|
||||
$ pip install x-transformers
|
||||
```
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.efficient import ViT
|
||||
from x_transformers import Encoder
|
||||
|
||||
v = ViT(
|
||||
dim = 512,
|
||||
image_size = 224,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
transformer = Encoder(
|
||||
dim = 512, # set to be the same as the wrapper
|
||||
depth = 12,
|
||||
heads = 8,
|
||||
ff_glu = True, # ex. feed forward GLU variant https://arxiv.org/abs/2002.05202
|
||||
residual_attn = True # ex. residual attention https://arxiv.org/abs/2012.11747
|
||||
)
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
|
||||
|
||||
1. <a href="http://jalammar.github.io/illustrated-transformer/">Illustrated Transformer</a> - Jay Alammar
|
||||
|
||||
2. <a href="http://peterbloem.nl/blog/transformers">Transformers from Scratch</a> - Peter Bloem
|
||||
|
||||
3. <a href="https://nlp.seas.harvard.edu/2018/04/03/attention.html">The Annotated Transformer</a> - Harvard NLP
|
||||
|
||||
|
||||
## Citations
|
||||
|
||||
```bibtex
|
||||
@@ -160,6 +263,17 @@ Other sparse attention frameworks I would highly recommend is <a href="https://g
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{touvron2020training,
|
||||
title = {Training data-efficient image transformers & distillation through attention},
|
||||
author = {Hugo Touvron and Matthieu Cord and Matthijs Douze and Francisco Massa and Alexandre Sablayrolles and Hervé Jégou},
|
||||
year = {2020},
|
||||
eprint = {2012.12877},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
@@ -170,3 +284,5 @@ Other sparse attention frameworks I would highly recommend is <a href="https://g
|
||||
primaryClass = {cs.CL}
|
||||
}
|
||||
```
|
||||
|
||||
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
|
||||
|
||||
BIN
distill.png
Normal file
BIN
distill.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 49 KiB |
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.4.0',
|
||||
version = '0.6.6',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
130
vit_pytorch/distill.py
Normal file
130
vit_pytorch/distill.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from vit_pytorch.vit_pytorch import ViT
|
||||
from vit_pytorch.efficient import ViT as EfficientViT
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
# classes
|
||||
|
||||
class DistillMixin:
|
||||
def forward(self, img, distill_token = None, mask = None):
|
||||
p, distilling = self.patch_size, exists(distill_token)
|
||||
|
||||
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
|
||||
x = self.patch_to_embedding(x)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
|
||||
if distilling:
|
||||
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((x, distill_tokens), dim = 1)
|
||||
|
||||
x = self._attend(x, mask)
|
||||
|
||||
if distilling:
|
||||
x, distill_tokens = x[:, :-1], x[:, -1]
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
out = self.mlp_head(x)
|
||||
|
||||
if distilling:
|
||||
return out, distill_tokens
|
||||
|
||||
return out
|
||||
|
||||
class DistillableViT(DistillMixin, ViT):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DistillableViT, self).__init__(*args, **kwargs)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.dim = kwargs['dim']
|
||||
self.num_classes = kwargs['num_classes']
|
||||
|
||||
def to_vit(self):
|
||||
v = ViT(*self.args, **self.kwargs)
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x, mask):
|
||||
x = self.dropout(x)
|
||||
x = self.transformer(x, mask)
|
||||
return x
|
||||
|
||||
class DistillableEfficientViT(DistillMixin, EfficientViT):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DistillableEfficientViT, self).__init__(*args, **kwargs)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.dim = kwargs['dim']
|
||||
self.num_classes = kwargs['num_classes']
|
||||
|
||||
def to_vit(self):
|
||||
v = EfficientViT(*self.args, **self.kwargs)
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x, mask):
|
||||
return self.transformer(x)
|
||||
|
||||
# knowledge distillation wrapper
|
||||
|
||||
class DistillWrapper(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
teacher,
|
||||
student,
|
||||
temperature = 1.,
|
||||
alpha = 0.5
|
||||
):
|
||||
super().__init__()
|
||||
assert (isinstance(student, (DistillableViT, DistillableEfficientViT))) , 'student must be a vision transformer'
|
||||
|
||||
self.teacher = teacher
|
||||
self.student = student
|
||||
|
||||
dim = student.dim
|
||||
num_classes = student.num_classes
|
||||
self.temperature = temperature
|
||||
self.alpha = alpha
|
||||
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
|
||||
self.distill_mlp = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
|
||||
b, *_ = img.shape
|
||||
alpha = alpha if exists(alpha) else self.alpha
|
||||
T = temperature if exists(temperature) else self.temperature
|
||||
|
||||
with torch.no_grad():
|
||||
teacher_logits = self.teacher(img)
|
||||
|
||||
student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
|
||||
distill_logits = self.distill_mlp(distill_tokens)
|
||||
|
||||
loss = F.cross_entropy(student_logits, labels)
|
||||
|
||||
distill_loss = F.kl_div(
|
||||
F.log_softmax(distill_logits / T, dim = -1),
|
||||
F.softmax(teacher_logits / T, dim = -1).detach(),
|
||||
reduction = 'batchmean')
|
||||
|
||||
distill_loss *= T ** 2
|
||||
|
||||
return loss * alpha + distill_loss * (1 - alpha)
|
||||
@@ -3,9 +3,10 @@ from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, channels = 3):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
@@ -16,7 +17,8 @@ class ViT(nn.Module):
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.transformer = transformer
|
||||
|
||||
self.to_cls_token = nn.Identity()
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
@@ -35,5 +37,7 @@ class ViT(nn.Module):
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.transformer(x)
|
||||
|
||||
x = self.to_cls_token(x[:, 0])
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
@@ -84,12 +84,13 @@ class Transformer(nn.Module):
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
||||
@@ -100,7 +101,8 @@ class ViT(nn.Module):
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.to_cls_token = nn.Identity()
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
@@ -121,5 +123,7 @@ class ViT(nn.Module):
|
||||
|
||||
x = self.transformer(x, mask)
|
||||
|
||||
x = self.to_cls_token(x[:, 0])
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
Reference in New Issue
Block a user