Mirror of https://github.com/lucidrains/vit-pytorch.git, synced 2025-12-30 16:12:29 +00:00

Compare commits

28 Commits
| SHA1 |
|---|
| aa9ed249a3 |
| ea0924ec96 |
| 59787a6b7e |
| 24339644ca |
| b786029e18 |
| 9624181940 |
| a656a213e6 |
| f1deb5fb7e |
| 3f50dd72cf |
| ee5e4e9929 |
| 6c8dfc185e |
| 4f84ad7a64 |
| c74bc781f0 |
| dc5b89c942 |
| c1043ab00c |
| 7a214d7109 |
| 6d1df1a970 |
| d65a8c17a5 |
| f7c164d910 |
| c7b74e0bc3 |
| 5b5d98a3a7 |
| b0e4790c24 |
| 0b2b3fc20c |
| ced464dcb4 |
| 5bf45a2d4d |
| fa32e22855 |
| a0fa41070f |
| b298031c17 |
README.md (146 lines changed)
@@ -1,9 +1,11 @@
-<img src="./vit.png" width="500px"></img>
+<img src="./vit.gif" width="500px"></img>

## Vision Transformer - Pytorch

Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.

+For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>

## Install

```bash
@@ -22,8 +24,10 @@ v = ViT(
    num_classes = 1000,
    dim = 1024,
    depth = 6,
-    heads = 8,
-    mlp_dim = 2048
+    heads = 16,
+    mlp_dim = 2048,
+    dropout = 0.1,
+    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
@@ -32,7 +36,76 @@ mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to attend to
preds = v(img, mask = mask) # (1, 1000)
```

## Suggestion
## Parameters
- `image_size`: int.
Image size.
- `patch_size`: int.
Size of patches. `image_size` must be divisible by `patch_size`.
The number of patches is `n = (image_size // patch_size) ** 2`, and `n` **must be greater than 16** (a quick check is sketched after this list).
- `num_classes`: int.
Number of classes to classify.
- `dim`: int.
Last dimension of the output tensor after the linear transformation `nn.Linear(..., dim)`.
- `depth`: int.
Number of Transformer blocks.
- `heads`: int.
Number of heads in the Multi-head Attention layer.
- `mlp_dim`: int.
Dimension of the MLP (FeedForward) layer.
- `channels`: int, default `3`.
Number of image channels.
- `dropout`: float between `[0, 1]`, default `0.`.
Dropout rate.
- `emb_dropout`: float between `[0, 1]`, default `0.`.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling.
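
A quick numeric sketch of the patch-count rule from the `patch_size` entry above; the numbers are illustrative only and are not part of the README.

```python
# illustrative numbers only: checking the patch-count rule
image_size, patch_size = 256, 32

assert image_size % patch_size == 0     # the image must tile evenly into patches
n = (image_size // patch_size) ** 2     # (256 // 32) ** 2 = 64 patches
assert n > 16                           # vit_pytorch requires strictly more than 16 patches

# patch_size = 64 would give n = 16 and fail the check; decrease the patch size instead
```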

## Distillation

A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that use of a distillation token for distilling knowledge from convolutional nets to vision transformers can yield small and efficient vision transformers. This repository offers the means to do distillation easily.

ex. distilling from Resnet50 (or any teacher) to a vision transformer

```python
import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    pool = 'mean'
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,    # temperature of distillation
    alpha = 0.5         # trade-off between the main loss and the distillation loss
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()
```

The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back into a plain `ViT` after you have completed distillation training.
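
A minimal sketch of that round trip, assuming the same hyperparameters as the `DistillableViT` constructed above; the variable names are illustrative, not from the README.

```python
# hypothetical follow-up to the distillation example above
from vit_pytorch.vit_pytorch import ViT

v_plain = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    pool = 'mean'
)

# DistillableViT registers no parameters of its own (the distillation token and
# its head live in DistillWrapper), so the state dicts should match key for key
v_plain.load_state_dict(v.state_dict())
```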

## Research Ideas

### Self Supervised Training

You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.

@@ -60,7 +133,7 @@ model = ViT(
learner = BYOL(
    model,
    image_size = 256,
-    hidden_layer = 'to_cls_token'
+    hidden_layer = 'to_latent'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
@@ -82,16 +155,63 @@ torch.save(model.state_dict(), './pretrained-net.pt')

A pytorch-lightning script is ready for you to use at the repository link above.

### Efficient Attention

There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plug in your own sparse attention transformer.

An example with <a href="https://arxiv.org/abs/2006.04768">Linformer</a>

```bash
$ pip install linformer
```

```python
import torch
from vit_pytorch.efficient import ViT
from linformer import Linformer

efficient_transformer = Linformer(
    dim = 512,
    seq_len = 4096 + 1,  # 64 x 64 patches + 1 cls token
    depth = 12,
    heads = 8,
    k = 256
)

v = ViT(
    dim = 512,
    image_size = 2048,
    patch_size = 32,
    num_classes = 1000,
    transformer = efficient_transformer
)

img = torch.randn(1, 3, 2048, 2048) # your high resolution picture
v(img) # (1, 1000)
```

Other sparse attention frameworks I would highly recommend are <a href="https://github.com/lucidrains/routing-transformer">Routing Transformer</a> and <a href="https://github.com/lucidrains/sinkhorn-transformer">Sinkhorn Transformer</a>.

## Citations

```bibtex
-@inproceedings{
-    anonymous2021an,
-    title     = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
-    author    = {Anonymous},
-    booktitle = {Submitted to International Conference on Learning Representations},
-    year      = {2021},
-    url       = {https://openreview.net/forum?id=YicbFdNTTy},
-    note      = {under review}
+@misc{dosovitskiy2020image,
+    title         = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
+    author        = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
+    year          = {2020},
+    eprint        = {2010.11929},
+    archivePrefix = {arXiv},
+    primaryClass  = {cs.CV}
}
```

```bibtex
@misc{vaswani2017attention,
    title         = {Attention Is All You Need},
    author        = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year          = {2017},
    eprint        = {1706.03762},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CL}
}
```
examples/cats_and_dogs.ipynb (new file, 6253 lines)
File diff suppressed because one or more lines are too long
setup.py (6 lines changed)
@@ -2,8 +2,8 @@ from setuptools import setup, find_packages

setup(
  name = 'vit-pytorch',
-  packages = find_packages(),
-  version = '0.0.5',
+  packages = find_packages(exclude=['examples']),
+  version = '0.6.0',
  license='MIT',
  description = 'Vision Transformer (ViT) - Pytorch',
  author = 'Phil Wang',
@@ -25,4 +25,4 @@ setup(
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
-)
+)
vit_pytorch/distill.py (new file, 88 lines)
@@ -0,0 +1,88 @@
import torch
import torch.nn.functional as F
from torch import nn
from vit_pytorch.vit_pytorch import ViT

from einops import rearrange, repeat

# helpers

def exists(val):
    return val is not None

# classes

class DistillableViT(ViT):
    def __init__(self, *args, **kwargs):
        super(DistillableViT, self).__init__(*args, **kwargs)
        self.dim = kwargs['dim']
        self.num_classes = kwargs['num_classes']

    def forward(self, img, distill_token, mask = None):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]

        distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
        x = torch.cat((x, distill_tokens), dim = 1)

        x = self.dropout(x)

        x = self.transformer(x, mask)

        x, distill_tokens = x[:, :-1], x[:, -1]
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x), distill_tokens

class DistillWrapper(nn.Module):
    def __init__(
        self,
        *,
        teacher,
        student,
        temperature = 1.,
        alpha = 0.5
    ):
        super().__init__()
        assert isinstance(student, DistillableViT), 'student must be a vision transformer'
        self.teacher = teacher
        self.student = student

        dim = student.dim
        num_classes = student.num_classes
        self.temperature = temperature
        self.alpha = alpha

        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))

        self.distill_mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, labels, temperature = None, **kwargs):
        b, *_, alpha = *img.shape, self.alpha
        T = temperature if exists(temperature) else self.temperature

        teacher_logits = self.teacher(img)
        student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
        distill_logits = self.distill_mlp(distill_tokens)

        loss = F.cross_entropy(student_logits, labels)

        distill_loss = F.kl_div(
            F.log_softmax(distill_logits / T, dim = -1),
            F.softmax(teacher_logits / T, dim = -1).detach(),
            reduction = 'batchmean')

        distill_loss *= T ** 2

        return loss * alpha + distill_loss * (1 - alpha)
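
The forward pass above mixes a hard cross-entropy loss on the student's class-token head with a temperature-scaled soft distillation loss on the distillation-token head. A standalone sketch of the same computation on dummy logits; the function name and numbers are illustrative only.

```python
import torch
import torch.nn.functional as F

def soft_distill_loss(student_logits, distill_logits, teacher_logits, labels, T = 3., alpha = 0.5):
    # hard loss on the student's class-token predictions
    ce = F.cross_entropy(student_logits, labels)

    # soft loss: KL between temperature-softened distill-head and teacher distributions,
    # scaled by T ** 2 as in DistillWrapper above
    kl = F.kl_div(
        F.log_softmax(distill_logits / T, dim = -1),
        F.softmax(teacher_logits / T, dim = -1).detach(),
        reduction = 'batchmean'
    ) * T ** 2

    # same weighting convention as the wrapper: alpha on the hard loss
    return ce * alpha + kl * (1 - alpha)

student_logits = torch.randn(2, 1000)
distill_logits = torch.randn(2, 1000)
teacher_logits = torch.randn(2, 1000)
labels = torch.randint(0, 1000, (2,))
print(soft_distill_loss(student_logits, distill_logits, teacher_logits, labels))
```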
vit_pytorch/efficient.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import torch
from einops import rearrange, repeat
from torch import nn

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = transformer

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
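
The `transformer` passed to `efficient.ViT` only needs to map a `(batch, seq, dim)` tensor to the same shape, where `seq = (image_size // patch_size) ** 2 + 1` accounts for the cls token. A sketch with torch's built-in encoder as a stand-in for a sparse attention module; this is an assumption about the interface, not an example from the repository, and it assumes a PyTorch version with `batch_first` support.

```python
import torch
from torch import nn
from vit_pytorch.efficient import ViT

dim = 256
image_size, patch_size = 256, 32
seq_len = (image_size // patch_size) ** 2 + 1   # 64 patches + 1 cls token = 65 tokens seen by the encoder

# stand-in transformer: anything mapping (batch, seq, dim) -> (batch, seq, dim) works here
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model = dim, nhead = 8, batch_first = True),
    num_layers = 6
)

v = ViT(
    dim = dim,
    image_size = image_size,
    patch_size = patch_size,
    num_classes = 1000,
    transformer = encoder
)

img = torch.randn(1, 3, image_size, image_size)
print(v(img).shape)   # torch.Size([1, 1000])
```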
vit_pytorch/vit_pytorch.py
@@ -1,8 +1,10 @@
import torch
import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange, repeat
from torch import nn

+MIN_NUM_PATCHES = 16

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
@@ -19,36 +21,44 @@ class PreNorm(nn.Module):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
-    def __init__(self, dim, hidden_dim):
+    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
-            nn.Linear(hidden_dim, dim)
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
-    def __init__(self, dim, heads = 8):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
+        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim ** -0.5

-        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
-        self.to_out = nn.Linear(dim, dim)
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        )

    def forward(self, x, mask = None):
        b, n, _, h = *x.shape, self.heads
-        qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
+        mask_value = -torch.finfo(dots.dtype).max

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
-            dots.masked_fill_(~mask, float('-inf'))
+            dots.masked_fill_(~mask, mask_value)
            del mask

        attn = dots.softmax(dim=-1)
@@ -59,13 +69,13 @@ class Attention(nn.Module):
        return out

class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, mlp_dim):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
-                Residual(PreNorm(dim, Attention(dim, heads = heads))),
-                Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
+                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
+                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            ]))
    def forward(self, x, mask = None):
        for attn, ff in self.layers:
@@ -74,25 +84,29 @@ class Transformer(nn.Module):
        return x

class ViT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
-        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
+        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
+        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
+        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.transformer = Transformer(dim, depth, heads, mlp_dim)
+        self.dropout = nn.Dropout(emb_dropout)

-        self.to_cls_token = nn.Identity()
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

+        self.pool = pool
+        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
-            nn.Linear(dim, mlp_dim),
-            nn.GELU(),
-            nn.Linear(mlp_dim, num_classes)
+            nn.LayerNorm(dim),
+            nn.Linear(dim, num_classes)
        )

    def forward(self, img, mask = None):
@@ -100,11 +114,16 @@ class ViT(nn.Module):

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
+        b, n, _ = x.shape

-        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
+        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding
+        x += self.pos_embedding[:, :(n + 1)]
+        x = self.dropout(x)

        x = self.transformer(x, mask)

-        x = self.to_cls_token(x[:, 0])
+        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

+        x = self.to_latent(x)
        return self.mlp_head(x)
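
For the masking branch in `Attention.forward` above: a small sketch of how the per-patch boolean mask is padded for the cls token and broadcast into a pairwise attention mask. The tiny sizes are illustrative only.

```python
import torch
import torch.nn.functional as F

b, n = 1, 4                                                  # batch of 1, 4 patches
patch_mask = torch.tensor([[True, True, False, False]])      # attend only to the first two patches

mask = F.pad(patch_mask.flatten(1), (1, 0), value = True)    # prepend True for the cls token -> (b, n + 1)
pairwise = mask[:, None, :] * mask[:, :, None]               # boolean outer product -> (b, n + 1, n + 1)

# positions where pairwise is False are filled with the most negative finite value
# before the softmax, so they receive effectively zero attention
print(pairwise.shape)      # torch.Size([1, 5, 5])
print(pairwise[0].int())
```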