Compare commits

...

31 Commits
0.2.2 ... 0.6.3

Author SHA1 Message Date
Phil Wang
74074e2b6c offer easy way to turn DistillableViT to ViT at the end of training 2020-12-25 11:16:52 -08:00
Phil Wang
0c68688d61 bump for release 2020-12-25 09:30:48 -08:00
Phil Wang
5918f301a2 cleanup 2020-12-25 09:30:38 -08:00
Phil Wang
4a6469eecc Merge pull request #51 from umbertov/main
Add class for distillation with efficient attention
2020-12-25 09:21:17 -08:00
Umberto Valleriani
5a225c8e3f Add class for distillation with efficient attention
DistillableEfficientViT does the same as DistillableViT, except it
may accept a custom transformer encoder, possibly implementing an
efficient attention mechanism
2020-12-25 13:46:29 +01:00
Phil Wang
e0007bd801 add distill diagram 2020-12-24 11:34:15 -08:00
Phil Wang
db98ed7a8e allow for overriding alpha as well on forward in distillation wrapper 2020-12-24 11:18:36 -08:00
Phil Wang
dc4b3327ce no grad for teacher in distillation 2020-12-24 11:11:58 -08:00
Phil Wang
aa8f0a7bf3 Update README.md 2020-12-24 10:59:03 -08:00
Phil Wang
34e6284f95 Update README.md 2020-12-24 10:58:41 -08:00
Phil Wang
aa9ed249a3 add knowledge distillation with distillation tokens, in light of new finding from facebook ai 2020-12-24 10:39:15 -08:00
Phil Wang
ea0924ec96 update readme 2020-12-23 19:06:48 -08:00
Phil Wang
59787a6b7e allow for mean pool with efficient version too 2020-12-23 18:15:40 -08:00
Phil Wang
24339644ca offer a way to use mean pooling of last layer 2020-12-23 17:23:58 -08:00
Phil Wang
b786029e18 fix the dimension per head to be independent of dim and heads, to make sure users do not have it be too small to learn anything 2020-12-17 07:43:52 -08:00
Phil Wang
9624181940 simplify mlp head 2020-12-07 14:31:50 -08:00
Phil Wang
a656a213e6 update diagram 2020-12-04 12:26:28 -08:00
Phil Wang
f1deb5fb7e Merge pull request #31 from minhlong94/main
Update README and documentation
2020-11-21 08:05:38 -08:00
Long M. Lưu
3f50dd72cf Update README.md 2020-11-21 18:37:03 +07:00
Long M. Lưu
ee5e4e9929 Update vit_pytorch.py 2020-11-21 18:23:04 +07:00
Phil Wang
6c8dfc185e remove float(-inf) as masking value 2020-11-13 12:25:21 -08:00
Phil Wang
4f84ad7a64 authors are now known 2020-11-03 14:28:20 -08:00
Phil Wang
c74bc781f0 cite 2020-11-03 11:59:05 -08:00
Phil Wang
dc5b89c942 use einops repeat 2020-10-28 18:13:57 -07:00
Phil Wang
c1043ab00c update readme 2020-10-26 19:01:03 -07:00
Phil Wang
7a214d7109 allow for training on different image sizes, provided images are smaller than what was passed as image_size keyword on init 2020-10-25 13:17:42 -07:00
Phil Wang
6d1df1a970 more efficient 2020-10-22 22:37:06 -07:00
Phil Wang
d65a8c17a5 remove dropout from last linear to logits 2020-10-16 13:58:23 -07:00
Phil Wang
f7c164d910 assert minimum number of patches 2020-10-16 12:19:50 -07:00
Phil Wang
c7b74e0bc3 rename ipy notebook 2020-10-14 10:35:46 -07:00
Phil Wang
5b5d98a3a7 dropouts are more specific and aggressive in the paper, thanks for letting me know @hila-chefer 2020-10-14 09:22:16 -07:00
9 changed files with 292 additions and 46 deletions

124
README.md
View File

@@ -1,9 +1,11 @@
<img src="./vit.png" width="500px"></img>
<img src="./vit.gif" width="500px"></img>
## Vision Transformer - Pytorch
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>
## Install
```bash
@@ -22,10 +24,10 @@ v = ViT(
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
heads = 16,
mlp_dim = 2048,
attn_dropout = 0.1,
ff_dropout = 0.1
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
@@ -34,6 +36,81 @@ mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to at
preds = v(img, mask = mask) # (1, 1000)
```
## Parameters
- `image_size`: int.
Image size.
- `patch_size`: int.
Number of patches. `image_size` must be divisible by `patch_size`.
The number of patches is: ` n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
- `num_classes`: int.
Number of classes to classify.
- `dim`: int.
Last dimension of output tensor after linear transformation `nn.Linear(..., dim)`.
- `depth`: int.
Number of Transformer blocks.
- `heads`: int.
Number of heads in Multi-head Attention layer.
- `mlp_dim`: int.
Dimension of the MLP (FeedForward) layer.
- `channels`: int, default `3`.
Number of image's channels.
- `dropout`: float between `[0, 1]`, default `0.`.
Dropout rate.
- `emb_dropout`: float between `[0, 1]`, default `0`.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling
## Distillation
<img src="./distill.png" width="300px"></img>
A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that use of a distillation token for distilling knowledge from convolutional nets to vision transformer can yield small and efficient vision transformers. This repository offers the means to do distillation easily.
ex. distilling from Resnet50 (or any teacher) to a vision transformer
```python
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper
teacher = resnet50(pretrained = True)
v = DistillableViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
distiller = DistillWrapper(
student = v,
teacher = teacher,
temperature = 3, # temperature of distillation
alpha = 0.5 # trade between main loss and distillation loss
)
img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))
loss = distiller(img, labels)
loss.backward()
```
The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back to `ViT` after you have completed distillation training.
You can also use the handy `.to_vit` method on the `DistillableViT` instance to get back a `ViT` instance.
```python
v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```
## Research Ideas
### Self Supervised Training
@@ -64,7 +141,7 @@ model = ViT(
learner = BYOL(
model,
image_size = 256,
hidden_layer = 'to_cls_token'
hidden_layer = 'to_latent'
)
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
@@ -126,13 +203,34 @@ Other sparse attention frameworks I would highly recommend is <a href="https://g
## Citations
```bibtex
@inproceedings{
anonymous2021an,
title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
author={Anonymous},
booktitle={Submitted to International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=YicbFdNTTy},
note={under review}
@misc{dosovitskiy2020image,
title = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
author = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
year = {2020},
eprint = {2010.11929},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{touvron2020training,
title = {Training data-efficient image transformers & distillation through attention},
author = {Hugo Touvron and Matthieu Cord and Matthijs Douze and Francisco Massa and Alexandre Sablayrolles and Hervé Jégou},
year = {2020},
eprint = {2012.12877},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
author = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
year = {2017},
eprint = {1706.03762},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
```

BIN
distill.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.2.1',
version = '0.6.4',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

BIN
vit.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.8 MiB

BIN
vit.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 137 KiB

130
vit_pytorch/distill.py Normal file
View File

@@ -0,0 +1,130 @@
import torch
import torch.nn.functional as F
from torch import nn
from vit_pytorch.vit_pytorch import ViT
from vit_pytorch.efficient import ViT as EfficientViT
from einops import rearrange, repeat
# helpers
def exists(val):
return val is not None
# classes
class DistillableViT(ViT):
def __init__(self, *args, **kwargs):
super(DistillableViT, self).__init__(*args, **kwargs)
self.args = args
self.kwargs = kwargs
self.dim = kwargs['dim']
self.num_classes = kwargs['num_classes']
def to_vit(self):
v = ViT(*self.args, **self.kwargs)
v.load_state_dict(self.state_dict())
return v
def forward(self, img, distill_token, mask = None):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
x = self.dropout(x)
x = self.transformer(x, mask)
x, distill_tokens = x[:, :-1], x[:, -1]
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x), distill_tokens
class DistillableEfficientViT(EfficientViT):
def __init__(self, *args, **kwargs):
super(DistillableEfficientViT, self).__init__(*args, **kwargs)
self.dim = kwargs['dim']
self.num_classes = kwargs['num_classes']
def forward(self, img, distill_token, mask = None):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
x = self.transformer(x)
x, distill_tokens = x[:, :-1], x[:, -1]
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x), distill_tokens
# knowledge distillation wrapper
class DistillWrapper(nn.Module):
def __init__(
self,
*,
teacher,
student,
temperature = 1.,
alpha = 0.5
):
super().__init__()
assert (isinstance(student, (DistillableViT, DistillableEfficientViT))) , 'student must be a vision transformer'
self.teacher = teacher
self.student = student
dim = student.dim
num_classes = student.num_classes
self.temperature = temperature
self.alpha = alpha
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
self.distill_mlp = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
b, *_ = img.shape
alpha = alpha if exists(alpha) else self.alpha
T = temperature if exists(temperature) else self.temperature
with torch.no_grad():
teacher_logits = self.teacher(img)
student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
distill_logits = self.distill_mlp(distill_tokens)
loss = F.cross_entropy(student_logits, labels)
distill_loss = F.kl_div(
F.log_softmax(distill_logits / T, dim = -1),
F.softmax(teacher_logits / T, dim = -1).detach(),
reduction = 'batchmean')
distill_loss *= T ** 2
return loss * alpha + distill_loss * (1 - alpha)

View File

@@ -1,11 +1,12 @@
import torch
from einops import rearrange
from einops import rearrange, repeat
from torch import nn
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, channels = 3):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
@@ -16,13 +17,12 @@ class ViT(nn.Module):
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = transformer
self.to_cls_token = nn.Identity()
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, num_classes)
nn.Linear(dim, num_classes)
)
def forward(self, img):
@@ -30,11 +30,14 @@ class ViT(nn.Module):
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :(n + 1)]
x = self.transformer(x)
x = self.to_cls_token(x[:, 0])
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)

View File

@@ -1,8 +1,10 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from einops import rearrange, repeat
from torch import nn
MIN_NUM_PATCHES = 16
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
@@ -25,36 +27,41 @@ class FeedForward(nn.Module):
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim)
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dropout = 0.):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim ** -0.5
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.to_out = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
mask_value = -torch.finfo(dots.dtype).max
if mask is not None:
mask = F.pad(mask.flatten(1), (1, 0), value = True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
dots.masked_fill_(~mask, float('-inf'))
dots.masked_fill_(~mask, mask_value)
del mask
attn = dots.softmax(dim=-1)
attn = self.dropout(attn)
out = torch.einsum('bhij,bhjd->bhid', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
@@ -62,13 +69,13 @@ class Attention(nn.Module):
return out
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, mlp_dim, attn_dropout, ff_dropout):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = attn_dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = ff_dropout)))
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
@@ -77,26 +84,29 @@ class Transformer(nn.Module):
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, attn_dropout = 0., ff_dropout = 0.):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = Transformer(dim, depth, heads, mlp_dim, attn_dropout, ff_dropout)
self.dropout = nn.Dropout(emb_dropout)
self.to_cls_token = nn.Identity()
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Linear(mlp_dim, num_classes)
nn.Linear(dim, num_classes)
)
def forward(self, img, mask = None):
@@ -104,11 +114,16 @@ class ViT(nn.Module):
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x, mask)
x = self.to_cls_token(x[:, 0])
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)