Compare commits

...

28 Commits
0.1.0 ... 0.6.0

Author SHA1 Message Date
Phil Wang
aa9ed249a3 add knowledge distillation with distillation tokens, in light of new finding from facebook ai 2020-12-24 10:39:15 -08:00
Phil Wang
ea0924ec96 update readme 2020-12-23 19:06:48 -08:00
Phil Wang
59787a6b7e allow for mean pool with efficient version too 2020-12-23 18:15:40 -08:00
Phil Wang
24339644ca offer a way to use mean pooling of last layer 2020-12-23 17:23:58 -08:00
Phil Wang
b786029e18 fix the dimension per head to be independent of dim and heads, to make sure users do not have it be too small to learn anything 2020-12-17 07:43:52 -08:00
Phil Wang
9624181940 simplify mlp head 2020-12-07 14:31:50 -08:00
Phil Wang
a656a213e6 update diagram 2020-12-04 12:26:28 -08:00
Phil Wang
f1deb5fb7e Merge pull request #31 from minhlong94/main
Update README and documentation
2020-11-21 08:05:38 -08:00
Long M. Lưu
3f50dd72cf Update README.md 2020-11-21 18:37:03 +07:00
Long M. Lưu
ee5e4e9929 Update vit_pytorch.py 2020-11-21 18:23:04 +07:00
Phil Wang
6c8dfc185e remove float(-inf) as masking value 2020-11-13 12:25:21 -08:00
Phil Wang
4f84ad7a64 authors are now known 2020-11-03 14:28:20 -08:00
Phil Wang
c74bc781f0 cite 2020-11-03 11:59:05 -08:00
Phil Wang
dc5b89c942 use einops repeat 2020-10-28 18:13:57 -07:00
Phil Wang
c1043ab00c update readme 2020-10-26 19:01:03 -07:00
Phil Wang
7a214d7109 allow for training on different image sizes, provided images are smaller than what was passed as image_size keyword on init 2020-10-25 13:17:42 -07:00
Phil Wang
6d1df1a970 more efficient 2020-10-22 22:37:06 -07:00
Phil Wang
d65a8c17a5 remove dropout from last linear to logits 2020-10-16 13:58:23 -07:00
Phil Wang
f7c164d910 assert minimum number of patches 2020-10-16 12:19:50 -07:00
Phil Wang
c7b74e0bc3 rename ipy notebook 2020-10-14 10:35:46 -07:00
Phil Wang
5b5d98a3a7 dropouts are more specific and aggressive in the paper, thanks for letting me know @hila-chefer 2020-10-14 09:22:16 -07:00
Phil Wang
b0e4790c24 bump package 2020-10-13 13:12:19 -07:00
Phil Wang
0b2b3fc20c add dropouts 2020-10-13 13:11:59 -07:00
Phil Wang
ced464dcb4 Update setup.py 2020-10-11 00:06:26 -07:00
Phil Wang
5bf45a2d4d Merge pull request #4 from adimyth/main
Image Classification Example
2020-10-10 19:12:31 -07:00
adimyth
fa32e22855 adds a classification example using 'cats & dogs' data 2020-10-11 03:15:19 +05:30
Phil Wang
a0fa41070f norm cls token before sending to mlp head 2020-10-10 12:08:42 -07:00
Phil Wang
b298031c17 write up example for using efficient transformers 2020-10-07 19:15:21 -07:00
8 changed files with 6561 additions and 38 deletions

146
README.md
View File

@@ -1,9 +1,11 @@
<img src="./vit.png" width="500px"></img>
<img src="./vit.gif" width="500px"></img>
## Vision Transformer - Pytorch
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>
## Install
```bash
@@ -22,8 +24,10 @@ v = ViT(
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
@@ -32,7 +36,76 @@ mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to at
preds = v(img, mask = mask) # (1, 1000)
```
## Suggestion
## Parameters
- `image_size`: int.
Image size.
- `patch_size`: int.
Number of patches. `image_size` must be divisible by `patch_size`.
The number of patches is: ` n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
- `num_classes`: int.
Number of classes to classify.
- `dim`: int.
Last dimension of output tensor after linear transformation `nn.Linear(..., dim)`.
- `depth`: int.
Number of Transformer blocks.
- `heads`: int.
Number of heads in Multi-head Attention layer.
- `mlp_dim`: int.
Dimension of the MLP (FeedForward) layer.
- `channels`: int, default `3`.
Number of image's channels.
- `dropout`: float between `[0, 1]`, default `0.`.
Dropout rate.
- `emb_dropout`: float between `[0, 1]`, default `0`.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling
## Distillation
A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that use of a distillation token for distilling knowledge from convolutional nets to vision transformer can yield small and efficient vision transformers. This repository offers the means to do distillation easily.
ex. distilling from Resnet50 (or any teacher) to a vision transformer
```python
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper
teacher = resnet50(pretrained = True)
v = DistillableViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1,
pool = 'mean'
)
distiller = DistillWrapper(
student = v,
teacher = teacher,
temperature = 3, # temperature of distillation
alpha = 0.5 # trade between main loss and distillation loss
)
img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))
loss = distiller(img, labels)
loss.backward()
```
The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back to `ViT` after you have completed distillation training.
## Research Ideas
### Self Supervised Training
You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
@@ -60,7 +133,7 @@ model = ViT(
learner = BYOL(
model,
image_size = 256,
hidden_layer = 'to_cls_token'
hidden_layer = 'to_latent'
)
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
@@ -82,16 +155,63 @@ torch.save(model.state_dict(), './pretrained-net.pt')
A pytorch-lightning script is ready for you to use at the repository link above.
### Efficient Attention
There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plugin your own sparse attention transformer.
An example with <a href="https://arxiv.org/abs/2006.04768">Linformer</a>
```bash
$ pip install linformer
```
```python
import torch
from vit_pytorch.efficient import ViT
from linformer import Linformer
efficient_transformer = Linformer(
dim = 512,
seq_len = 4096 + 1, # 64 x 64 patches + 1 cls token
depth = 12,
heads = 8,
k = 256
)
v = ViT(
dim = 512,
image_size = 2048,
patch_size = 32,
num_classes = 1000,
transformer = efficient_transformer
)
img = torch.randn(1, 3, 2048, 2048) # your high resolution picture
v(img) # (1, 1000)
```
Other sparse attention frameworks I would highly recommend is <a href="https://github.com/lucidrains/routing-transformer">Routing Transformer</a> or <a href="https://github.com/lucidrains/sinkhorn-transformer">Sinkhorn Transformer</a>
## Citations
```bibtex
@inproceedings{
anonymous2021an,
title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
author={Anonymous},
booktitle={Submitted to International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=YicbFdNTTy},
note={under review}
@misc{dosovitskiy2020image,
title = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
author = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
year = {2020},
eprint = {2010.11929},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
author = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
year = {2017},
eprint = {1706.03762},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
```

6253
examples/cats_and_dogs.ipynb Normal file

File diff suppressed because one or more lines are too long

View File

@@ -2,8 +2,8 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(),
version = '0.0.5',
packages = find_packages(exclude=['examples']),
version = '0.6.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
@@ -25,4 +25,4 @@ setup(
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
)

BIN
vit.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.8 MiB

BIN
vit.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 137 KiB

88
vit_pytorch/distill.py Normal file
View File

@@ -0,0 +1,88 @@
import torch
import torch.nn.functional as F
from torch import nn
from vit_pytorch.vit_pytorch import ViT
from einops import rearrange, repeat
# helpers
def exists(val):
return val is not None
# classes
class DistillableViT(ViT):
def __init__(self, *args, **kwargs):
super(DistillableViT, self).__init__(*args, **kwargs)
self.dim = kwargs['dim']
self.num_classes = kwargs['num_classes']
def forward(self, img, distill_token, mask = None):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
x = self.dropout(x)
x = self.transformer(x, mask)
x, distill_tokens = x[:, :-1], x[:, -1]
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x), distill_tokens
class DistillWrapper(nn.Module):
def __init__(
self,
*,
teacher,
student,
temperature = 1.,
alpha = 0.5
):
super().__init__()
assert isinstance(student, DistillableViT), 'student must be a vision transformer'
self.teacher = teacher
self.student = student
dim = student.dim
num_classes = student.num_classes
self.temperature = temperature
self.alpha = alpha
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
self.distill_mlp = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img, labels, temperature = None, **kwargs):
b, *_, alpha = *img.shape, self.alpha
T = temperature if exists(temperature) else self.temperature
teacher_logits = self.teacher(img)
student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
distill_logits = self.distill_mlp(distill_tokens)
loss = F.cross_entropy(student_logits, labels)
distill_loss = F.kl_div(
F.log_softmax(distill_logits / T, dim = -1),
F.softmax(teacher_logits / T, dim = -1).detach(),
reduction = 'batchmean')
distill_loss *= T ** 2
return loss * alpha + distill_loss * (1 - alpha)

43
vit_pytorch/efficient.py Normal file
View File

@@ -0,0 +1,43 @@
import torch
from einops import rearrange, repeat
from torch import nn
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = transformer
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)

View File

@@ -1,8 +1,10 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from einops import rearrange, repeat
from torch import nn
MIN_NUM_PATCHES = 16
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
@@ -19,36 +21,44 @@ class PreNorm(nn.Module):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim)
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim ** -0.5
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.to_out = nn.Linear(dim, dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
mask_value = -torch.finfo(dots.dtype).max
if mask is not None:
mask = F.pad(mask.flatten(1), (1, 0), value = True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
dots.masked_fill_(~mask, float('-inf'))
dots.masked_fill_(~mask, mask_value)
del mask
attn = dots.softmax(dim=-1)
@@ -59,13 +69,13 @@ class Attention(nn.Module):
return out
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, mlp_dim):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
@@ -74,25 +84,29 @@ class Transformer(nn.Module):
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = Transformer(dim, depth, heads, mlp_dim)
self.dropout = nn.Dropout(emb_dropout)
self.to_cls_token = nn.Identity()
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Linear(mlp_dim, num_classes)
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img, mask = None):
@@ -100,11 +114,16 @@ class ViT(nn.Module):
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x, mask)
x = self.to_cls_token(x[:, 0])
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)