Mirror of https://github.com/lucidrains/vit-pytorch.git, synced 2026-05-09 17:12:35 +00:00

Compare commits

11 Commits

- bed61becd9
- 4fc7365356
- 3f2cbc6e23
- 85314cf0b6
- 5db8d9deed
- e8ca6038c9
- 1106a2ba88
- f95fa59422
- be1712ebe2
- 1a76944124
- 2263b7396f

README.md (103 lines changed)
@@ -4,7 +4,9 @@

Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.

-For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>
+For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>.
+
+The official Jax repository is <a href="https://github.com/google-research/vision_transformer">here</a>.

## Install

@@ -38,7 +40,7 @@ preds = v(img, mask = mask) # (1, 1000)

## Parameters

- `image_size`: int.
-Image size.
+Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
- `patch_size`: int.
-Number of patches. `image_size` must be divisible by `patch_size`.
+The number of patches is: `n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
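
As a quick sanity check of that constraint, with the default 224×224 input used throughout this README:

```python
image_size, patch_size = 224, 16

n = (image_size // patch_size) ** 2
print(n)  # 196 patch tokens, comfortably above the minimum of 16

# a patch size that is too coarse trips the assert in vit_pytorch.py:
# (224 // 56) ** 2 = 16, which fails the strict n > 16 check
```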

@@ -100,6 +102,10 @@ labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after lots of training above ...

pred = v(img) # (2, 1000)
```

+The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back to `ViT` after you have completed distillation training.

@@ -111,6 +117,31 @@ v = v.to_vit()

type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```
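
To make the round trip concrete, here is a minimal sketch (the constructor sizes are hypothetical; `to_vit` itself is shown in the distill.py hunks below):

```python
import torch
from vit_pytorch.distill import DistillableViT

v = DistillableViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024
)

# ... run distillation training on v ...

# to_vit rebuilds a plain ViT from the saved constructor args
# and copies the trained weights into it via load_state_dict
vit = v.to_vit()
torch.save(vit.state_dict(), './trained-vit.pt')
```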

+## Token to Token ViT
+
+<img src="./t2t.png" width="400px"></img>
+
+This paper proposes that the first couple layers should downsample the image sequence by unfolding, leading to overlapping image data in each token as shown in the figure above. You can use this variant of the `ViT` as follows.
+
+```python
+import torch
+from vit_pytorch.t2t import T2TViT
+
+v = T2TViT(
+    dim = 512,
+    image_size = 224,
+    patch_size = 16,
+    depth = 5,
+    heads = 8,
+    mlp_dim = 512,
+    num_classes = 1000,
+    t2t_layers = ((7, 4), (3, 2), (3, 2)) # tuples of the kernel size and stride of each consecutive layer of unfolds, for the initial 'token to token' downsampling layers
+)
+
+img = torch.randn(1, 3, 224, 224)
+v(img) # (1, 1000)
+```
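
For intuition about the unfolding step, here is a rough sketch of what the first `(7, 4)` t2t layer does to a 224×224 input (shapes follow from `nn.Unfold`; the padding of `stride // 2` matches the t2t.py code added below):

```python
import torch
from torch import nn
from einops import rearrange

img = torch.randn(1, 3, 224, 224)

# first 'token to token' stage: kernel size 7, stride 4, padding 4 // 2 = 2
unfold = nn.Unfold(kernel_size = 7, stride = 4, padding = 2)

x = unfold(img)                     # (1, 3 * 7 * 7, 56 * 56) = (1, 147, 3136)
x = rearrange(x, 'b c n -> b n c')  # (1, 3136, 147)

# because the stride (4) is smaller than the kernel (7), neighboring
# tokens share pixels - the 'overlapping image data' referred to above
print(x.shape)  # torch.Size([1, 3136, 147])
```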

## Research Ideas

### Self Supervised Training

@@ -167,23 +198,22 @@ A pytorch-lightning script is ready for you to use at the repository link above.

There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plug in your own sparse attention transformer.

-An example with <a href="https://arxiv.org/abs/2006.04768">Linformer</a>
+An example with <a href="https://arxiv.org/abs/2102.03902">Nystromformer</a>

```bash
-$ pip install linformer
+$ pip install nystrom-attention
```

```python
import torch
from vit_pytorch.efficient import ViT
-from linformer import Linformer
+from nystrom_attention import Nystromformer

-efficient_transformer = Linformer(
+efficient_transformer = Nystromformer(
    dim = 512,
    seq_len = 4096 + 1, # 64 x 64 patches + 1 cls token
    depth = 12,
    heads = 8,
-    k = 256
+    num_landmarks = 256
)

v = ViT(

@@ -200,6 +230,50 @@ v(img) # (1, 1000)

Other sparse attention frameworks I would highly recommend are <a href="https://github.com/lucidrains/routing-transformer">Routing Transformer</a> or <a href="https://github.com/lucidrains/sinkhorn-transformer">Sinkhorn Transformer</a>

+### Combining with other Transformer improvements
+
+This paper purposely used the most vanilla of attention networks to make a statement. If you would like to use some of the latest improvements for attention nets, please use the `Encoder` from <a href="https://github.com/lucidrains/x-transformers">this repository</a>.
+
+ex.
+
+```bash
+$ pip install x-transformers
+```
+
+```python
+import torch
+from vit_pytorch.efficient import ViT
+from x_transformers import Encoder
+
+v = ViT(
+    dim = 512,
+    image_size = 224,
+    patch_size = 16,
+    num_classes = 1000,
+    transformer = Encoder(
+        dim = 512,              # set to be the same as the wrapper
+        depth = 12,
+        heads = 8,
+        ff_glu = True,          # ex. feed forward GLU variant https://arxiv.org/abs/2002.05202
+        residual_attn = True    # ex. residual attention https://arxiv.org/abs/2012.11747
+    )
+)
+
+img = torch.randn(1, 3, 224, 224)
+v(img) # (1, 1000)
+```

+## Resources
+
+Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
+
+1. <a href="http://jalammar.github.io/illustrated-transformer/">Illustrated Transformer</a> - Jay Alammar
+
+2. <a href="http://peterbloem.nl/blog/transformers">Transformers from Scratch</a> - Peter Bloem
+
+3. <a href="https://nlp.seas.harvard.edu/2018/04/03/attention.html">The Annotated Transformer</a> - Harvard NLP

## Citations

```bibtex
@@ -224,6 +298,17 @@ Other sparse attention frameworks I would highly recommend is <a href="https://g

}
```

+```bibtex
+@misc{yuan2021tokenstotoken,
+    title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
+    author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
+    year = {2021},
+    eprint = {2101.11986},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```

```bibtex
@misc{vaswani2017attention,
    title = {Attention Is All You Need},

@@ -234,3 +319,5 @@ Other sparse attention frameworks I would highly recommend is <a href="https://g

    primaryClass = {cs.CL}
}
```

+*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon

setup.py (2 lines changed)

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages

setup(
    name = 'vit-pytorch',
    packages = find_packages(exclude=['examples']),
-    version = '0.6.4',
+    version = '0.7.0',
    license='MIT',
    description = 'Vision Transformer (ViT) - Pytorch',
    author = 'Phil Wang',

vit_pytorch/distill.py

@@ -13,7 +13,38 @@ def exists(val):

# classes

-class DistillableViT(ViT):
+class DistillMixin:
+    def forward(self, img, distill_token = None, mask = None):
+        p, distilling = self.patch_size, exists(distill_token)
+
+        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
+        x = self.patch_to_embedding(x)
+        b, n, _ = x.shape
+
+        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
+        x = torch.cat((cls_tokens, x), dim = 1)
+        x += self.pos_embedding[:, :(n + 1)]
+
+        if distilling:
+            distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
+            x = torch.cat((x, distill_tokens), dim = 1)
+
+        x = self._attend(x, mask)
+
+        if distilling:
+            x, distill_tokens = x[:, :-1], x[:, -1]
+
+        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
+
+        x = self.to_latent(x)
+        out = self.mlp_head(x)
+
+        if distilling:
+            return out, distill_tokens
+
+        return out
+
+class DistillableViT(DistillMixin, ViT):
    def __init__(self, *args, **kwargs):
        super(DistillableViT, self).__init__(*args, **kwargs)
        self.args = args

@@ -26,57 +57,26 @@ class DistillableViT(ViT):

        v.load_state_dict(self.state_dict())
        return v

-    def forward(self, img, distill_token, mask = None):
-        p = self.patch_size

-        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
-        x = self.patch_to_embedding(x)
-        b, n, _ = x.shape

-        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
-        x = torch.cat((cls_tokens, x), dim = 1)
-        x += self.pos_embedding[:, :(n + 1)]

-        distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
-        x = torch.cat((x, distill_tokens), dim = 1)

+    def _attend(self, x, mask):
        x = self.dropout(x)
        x = self.transformer(x, mask)
+        return x

-        x, distill_tokens = x[:, :-1], x[:, -1]
-        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

-        x = self.to_latent(x)
-        return self.mlp_head(x), distill_tokens

-class DistillableEfficientViT(EfficientViT):
+class DistillableEfficientViT(DistillMixin, EfficientViT):
    def __init__(self, *args, **kwargs):
        super(DistillableEfficientViT, self).__init__(*args, **kwargs)
        self.args = args
        self.kwargs = kwargs
        self.dim = kwargs['dim']
        self.num_classes = kwargs['num_classes']

-    def forward(self, img, distill_token, mask = None):
-        p = self.patch_size
+    def to_vit(self):
+        v = EfficientViT(*self.args, **self.kwargs)
+        v.load_state_dict(self.state_dict())
+        return v

-        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
-        x = self.patch_to_embedding(x)
-        b, n, _ = x.shape

-        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
-        x = torch.cat((cls_tokens, x), dim = 1)
-        x += self.pos_embedding[:, :(n + 1)]

-        distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
-        x = torch.cat((x, distill_tokens), dim = 1)

-        x = self.transformer(x)

-        x, distill_tokens = x[:, :-1], x[:, -1]
-        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

-        x = self.to_latent(x)
-        return self.mlp_head(x), distill_tokens
+    def _attend(self, x, mask):
+        return self.transformer(x)

# knowledge distillation wrapper
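
The refactor above is a textbook mixin: `DistillMixin` owns the shared `forward`, and each concrete class only supplies `_attend` to say how the token sequence runs through its transformer. A stripped-down sketch of the dispatch (hypothetical class names, not from the repository):

```python
# because the mixin precedes the base class in the MRO, its forward()
# shadows the base class's, while everything else is still inherited
class BaseModel:
    def forward(self, x):
        return f'base transform of {x}'

class SharedForwardMixin:
    def forward(self, x):
        return self._attend(x)  # deferred to the concrete class

class ModelA(SharedForwardMixin, BaseModel):
    def _attend(self, x):
        return f'dropout -> transformer({x}, mask)'

class ModelB(SharedForwardMixin, BaseModel):
    def _attend(self, x):
        return f'transformer({x})'

print(ModelA().forward('tokens'))  # dropout -> transformer(tokens, mask)
print(ModelB().forward('tokens'))  # transformer(tokens)
```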

vit_pytorch/t2t.py (new file, 72 lines)
@@ -0,0 +1,72 @@

+import math
+import torch
+from torch import nn
+
+from vit_pytorch.vit_pytorch import Transformer
+
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+
+# classes
+
+class RearrangeImage(nn.Module):
+    def forward(self, x):
+        return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1])))
+
+# main class
+
+class T2TViT(nn.Module):
+    def __init__(
+        self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
+        super().__init__()
+        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
+        num_patches = (image_size // patch_size) ** 2
+        patch_dim = channels * patch_size ** 2
+        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
+
+        layers = []
+        layer_dim = channels
+
+        for i, (kernel_size, stride) in enumerate(t2t_layers):
+            layer_dim *= kernel_size ** 2
+            is_first = i == 0
+
+            layers.extend([
+                RearrangeImage() if not is_first else nn.Identity(),
+                nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
+                Rearrange('b c n -> b n c'),
+                Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout),
+            ])
+
+        layers.append(nn.Linear(layer_dim, dim))
+        self.to_patch_embedding = nn.Sequential(*layers)
+
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.dropout = nn.Dropout(emb_dropout)
+
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+
+        self.pool = pool
+        self.to_latent = nn.Identity()
+
+        self.mlp_head = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, num_classes)
+        )
+
+    def forward(self, img):
+        x = self.to_patch_embedding(img)
+        b, n, _ = x.shape
+
+        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x += self.pos_embedding[:, :(n + 1)]
+        x = self.dropout(x)
+
+        x = self.transformer(x)
+
+        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
+
+        x = self.to_latent(x)
+        return self.mlp_head(x)

vit_pytorch/vit_pytorch.py

@@ -1,7 +1,10 @@

+import math
import torch
+from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
-from torch import nn
from einops.layers.torch import Rearrange

MIN_NUM_PATCHES = 16

@@ -38,7 +41,7 @@ class Attention(nn.Module):

        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
-        self.scale = dim ** -0.5
+        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(

@@ -51,25 +54,25 @@ class Attention(nn.Module):

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

-        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
+        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        mask_value = -torch.finfo(dots.dtype).max

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
-            mask = mask[:, None, :] * mask[:, :, None]
+            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')
            dots.masked_fill_(~mask, mask_value)
            del mask

        attn = dots.softmax(dim=-1)

-        out = torch.einsum('bhij,bhjd->bhid', attn, v)
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
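
Two of the changes above are easy to miss. Scaling by `dim_head ** -0.5` matches scaled dot-product attention as described in Vaswani et al., where the scale is the per-head dimension, not the model width. And the rewritten mask now broadcasts correctly against the `(batch, heads, i, j)` attention matrix: `mask[:, None, :] * mask[:, :, None]` yields shape `(b, i, j)`, which cannot broadcast against `(b, h, i, j)` in general, while the `rearrange` form yields `(b, 1, i, 1) * (b, 1, 1, j) = (b, 1, i, j)`. A quick check of that shape arithmetic (sizes are illustrative):

```python
import torch
from einops import rearrange

b, h, n = 2, 8, 5
dots = torch.randn(b, h, n, n)
mask = torch.rand(b, n) > 0.5

old = mask[:, None, :] * mask[:, :, None]                                        # (b, n, n)
new = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')  # (b, 1, n, n)

print(old.shape, new.shape)      # torch.Size([2, 5, 5]) torch.Size([2, 1, 5, 5])
dots.masked_fill_(~new, -1e9)    # broadcasts cleanly over the head dimension
# dots.masked_fill_(~old, -1e9)  # RuntimeError: (2, 5, 5) cannot broadcast to (2, 8, 5, 5)
```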

@@ -92,10 +95,12 @@ class ViT(nn.Module):

        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size = patch_size
+        self.to_patch_embedding = nn.Sequential(
+            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+            nn.Linear(patch_dim, dim),
+        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
-        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

@@ -110,10 +115,7 @@ class ViT(nn.Module):

        )

    def forward(self, img, mask = None):
-        p = self.patch_size

-        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
-        x = self.patch_to_embedding(x)
+        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
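
This forward-pass change is purely a repackaging: the `Rearrange` layer inside the `nn.Sequential` performs the same einops transform the old forward did by hand, so both formulations produce identical embeddings given the same weights. A small sketch of that equivalence (sizes are hypothetical):

```python
import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange

p, dim, patch_dim = 16, 512, 3 * 16 * 16
linear = nn.Linear(patch_dim, dim)

# new style: patchify and project in one nn.Sequential
to_patch_embedding = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p),
    linear,
)

img = torch.randn(1, 3, 224, 224)

# old style: rearrange by hand, then the same linear projection
old = linear(rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p))
new = to_patch_embedding(img)

print(torch.allclose(old, new))  # True
```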