Compare commits

...

19 Commits
0.6.1 ... 0.7.0

Author  SHA1  Message  Date
Phil Wang  bed61becd9  add token to token ViT  2021-02-19 22:25:58 -08:00
Phil Wang  4fc7365356  incept idea for using nystromformer  2021-02-17 15:30:45 -08:00
Phil Wang  3f2cbc6e23  fix for ambiguity in broadcasting mask  2021-02-17 07:38:11 -08:00
Phil Wang  85314cf0b6  patch for scaling factor, thanks to @urkax  2021-01-21 09:39:42 -08:00
Phil Wang  5db8d9deed  update readme about non-square images  2021-01-12 06:55:45 -08:00
Phil Wang  e8ca6038c9  allow for DistillableViT to still run predictions  2021-01-11 10:49:14 -08:00
Phil Wang  1106a2ba88  link to official repo  2021-01-08 08:23:50 -08:00
Phil Wang  f95fa59422  link to resources for vision people  2021-01-04 10:10:54 -08:00
Phil Wang  be1712ebe2  add quote  2020-12-28 10:22:59 -08:00
Phil Wang  1a76944124  update readme  2020-12-27 19:10:38 -08:00
Phil Wang  2263b7396f  allow distillable efficient vit to restore efficient vit as well  2020-12-25 19:31:25 -08:00
Phil Wang  74074e2b6c  offer easy way to turn DistillableViT to ViT at the end of training  2020-12-25 11:16:52 -08:00
Phil Wang  0c68688d61  bump for release  2020-12-25 09:30:48 -08:00
Phil Wang  5918f301a2  cleanup  2020-12-25 09:30:38 -08:00
Phil Wang  4a6469eecc  Merge pull request #51 from umbertov/main (Add class for distillation with efficient attention)  2020-12-25 09:21:17 -08:00
Umberto Valleriani  5a225c8e3f  Add class for distillation with efficient attention (DistillableEfficientViT does the same as DistillableViT, except it may accept a custom transformer encoder, possibly implementing an efficient attention mechanism)  2020-12-25 13:46:29 +01:00
Phil Wang  e0007bd801  add distill diagram  2020-12-24 11:34:15 -08:00
Phil Wang  db98ed7a8e  allow for overriding alpha as well on forward in distillation wrapper  2020-12-24 11:18:36 -08:00
Phil Wang  dc4b3327ce  no grad for teacher in distillation  2020-12-24 11:11:58 -08:00
7 changed files with 252 additions and 41 deletions

README.md (115 changed lines)

@@ -4,7 +4,9 @@
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>.
The official Jax repository is <a href="https://github.com/google-research/vision_transformer">here</a>.
## Install
@@ -38,7 +40,7 @@ preds = v(img, mask = mask) # (1, 1000)
## Parameters
- `image_size`: int.
Image size.
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
- `patch_size`: int.
Size of patches. `image_size` must be divisible by `patch_size`.
The number of patches is `n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
@@ -62,6 +64,8 @@ Embedding dropout rate.
## Distillation
<img src="./distill.png" width="300px"></img>
A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that using a distillation token for distilling knowledge from convolutional nets to vision transformers can yield small and efficient vision transformers. This repository offers the means to do distillation easily.
ex. distilling from Resnet50 (or any teacher) to a vision transformer
@@ -83,8 +87,7 @@ v = DistillableViT(
heads = 8,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1,
pool = 'mean'
emb_dropout = 0.1
)
distiller = DistillWrapper(
@@ -99,10 +102,46 @@ labels = torch.randint(0, 1000, (2,))
loss = distiller(img, labels)
loss.backward()
# after lots of training above ...
pred = v(img) # (2, 1000)
```
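Because the hunk above only shows the tail end of the example, here is a complete sketch of the distillation setup. It assumes the classes are importable from `vit_pytorch.distill` and that `DistillWrapper` accepts `student`, `teacher`, `temperature`, and `alpha` (consistent with the wrapper signature shown further down in this diff); treat the exact names as assumptions rather than the documented API.
```python
import torch
from torchvision.models import resnet50
# module path is an assumption -- the distillation classes may live elsewhere in the package
from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)   # any 1000-class teacher network works

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,   # softens both teacher and student logits
    alpha = 0.5        # weight between hard label loss and distillation loss
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after training, the distillable model can still predict directly
pred = v(img)   # (2, 1000)
```
Per the forward-signature change further down in this diff (commit db98ed7a8e), `temperature` and `alpha` can also be overridden per call, e.g. `distiller(img, labels, temperature = 5, alpha = 0.3)`.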
The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back to `ViT` after you have completed distillation training.
You can also use the handy `.to_vit` method on the `DistillableViT` instance to get back a `ViT` instance.
```python
v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```
## Token to Token ViT
<img src="./t2t.png" width="400px"></img>
This paper proposes that the first couple of layers should downsample the image sequence by unfolding, leading to overlapping image data in each token as shown in the figure above. You can use this variant of the `ViT` as follows.
```python
import torch
from vit_pytorch.t2t import T2TViT
v = T2TViT(
dim = 512,
image_size = 224,
patch_size = 16,
depth = 5,
heads = 8,
mlp_dim = 512,
num_classes = 1000,
t2t_layers = ((7, 4), (3, 2), (3, 2)) # tuples of the kernel size and stride of each consecutive layer of unfolds, for the initial 'token to token' downsampling layers
)
img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)
```
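As a sanity check on `t2t_layers`: each `(kernel_size, stride)` unfold is applied with `padding = stride // 2` (see `t2t.py` below), and the token count left after the unfolds should line up with `(image_size // patch_size) ** 2` so the positional embedding covers every token. A quick sketch of that arithmetic for the settings above:
```python
# token-count arithmetic for t2t_layers = ((7, 4), (3, 2), (3, 2)) on a 224px image
size = 224
for kernel, stride in ((7, 4), (3, 2), (3, 2)):
    padding = stride // 2
    size = (size + 2 * padding - kernel) // stride + 1   # standard unfold output size

print(size * size)        # 196 tokens after the three unfold stages
print((224 // 16) ** 2)   # 196, i.e. (image_size // patch_size) ** 2
```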
## Research Ideas
### Self Supervised Training
@@ -159,23 +198,22 @@ A pytorch-lightning script is ready for you to use at the repository link above.
There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plug in your own sparse attention transformer.
An example with <a href="https://arxiv.org/abs/2006.04768">Linformer</a>
An example with <a href="https://arxiv.org/abs/2102.03902">Nystromformer</a>
```bash
$ pip install linformer
$ pip install nystrom-attention
```
```python
import torch
from vit_pytorch.efficient import ViT
from linformer import Linformer
from nystrom_attention import Nystromformer
efficient_transformer = Linformer(
efficient_transformer = Nystromformer(
dim = 512,
seq_len = 4096 + 1, # 64 x 64 patches + 1 cls token
depth = 12,
heads = 8,
k = 256
num_landmarks = 256
)
v = ViT(
@@ -192,6 +230,50 @@ v(img) # (1, 1000)
Other sparse attention frameworks I would highly recommend are <a href="https://github.com/lucidrains/routing-transformer">Routing Transformer</a> and <a href="https://github.com/lucidrains/sinkhorn-transformer">Sinkhorn Transformer</a>.
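For reference, the complete efficient-attention example after this change would read roughly as follows. This is a sketch: it assumes the `nystrom-attention` package exposes `Nystromformer` with these argument names, and that (unlike Linformer) no fixed sequence length needs to be declared, since Nystromformer approximates attention through a fixed number of landmarks.
```python
import torch
from vit_pytorch.efficient import ViT
from nystrom_attention import Nystromformer

efficient_transformer = Nystromformer(
    dim = 512,
    depth = 12,
    heads = 8,
    num_landmarks = 256   # landmark count controls the quality of the attention approximation
)

v = ViT(
    dim = 512,
    image_size = 2048,
    patch_size = 32,
    num_classes = 1000,
    transformer = efficient_transformer
)

img = torch.randn(1, 3, 2048, 2048)   # 64 x 64 patches + 1 cls token
v(img)  # (1, 1000)
```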
### Combining with other Transformer improvements
This paper purposely used the most vanilla of attention networks to make a statement. If you would like to use some of the latest improvements for attention nets, please use the `Encoder` from <a href="https://github.com/lucidrains/x-transformers">this repository</a>.
ex.
```bash
$ pip install x-transformers
```
```python
import torch
from vit_pytorch.efficient import ViT
from x_transformers import Encoder
v = ViT(
dim = 512,
image_size = 224,
patch_size = 16,
num_classes = 1000,
transformer = Encoder(
dim = 512, # set to be the same as the wrapper
depth = 12,
heads = 8,
ff_glu = True, # ex. feed forward GLU variant https://arxiv.org/abs/2002.05202
residual_attn = True # ex. residual attention https://arxiv.org/abs/2012.11747
)
)
img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)
```
## Resources
Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
1. <a href="http://jalammar.github.io/illustrated-transformer/">Illustrated Transformer</a> - Jay Alammar
2. <a href="http://peterbloem.nl/blog/transformers">Transformers from Scratch</a> - Peter Bloem
3. <a href="https://nlp.seas.harvard.edu/2018/04/03/attention.html">The Annotated Transformer</a> - Harvard NLP
## Citations
```bibtex
@@ -216,6 +298,17 @@ Other sparse attention frameworks I would highly recommend is <a href="https://g
}
```
```bibtex
@misc{yuan2021tokenstotoken,
title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
year = {2021},
eprint = {2101.11986},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
@@ -226,3 +319,5 @@ Other sparse attention frameworks I would highly recommend is <a href="https://g
primaryClass = {cs.CL}
}
```
*I visualise a time when we will be to robots what dogs are to humans, and I'm rooting for the machines.* — Claude Shannon

distill.png (new binary file, 49 KiB)

setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.6.0',
version = '0.7.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

t2t.png (new binary file, 109 KiB)

vit_pytorch/distill.py

@@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from vit_pytorch.vit_pytorch import ViT
from vit_pytorch.efficient import ViT as EfficientViT
from einops import rearrange, repeat
@@ -12,14 +13,9 @@ def exists(val):
# classes
class DistillableViT(ViT):
def __init__(self, *args, **kwargs):
super(DistillableViT, self).__init__(*args, **kwargs)
self.dim = kwargs['dim']
self.num_classes = kwargs['num_classes']
def forward(self, img, distill_token, mask = None):
p = self.patch_size
class DistillMixin:
def forward(self, img, distill_token = None, mask = None):
p, distilling = self.patch_size, exists(distill_token)
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
@@ -29,18 +25,60 @@ class DistillableViT(ViT):
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
if distilling:
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
x = self.dropout(x)
x = self._attend(x, mask)
x = self.transformer(x, mask)
if distilling:
x, distill_tokens = x[:, :-1], x[:, -1]
x, distill_tokens = x[:, :-1], x[:, -1]
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x), distill_tokens
out = self.mlp_head(x)
if distilling:
return out, distill_tokens
return out
class DistillableViT(DistillMixin, ViT):
def __init__(self, *args, **kwargs):
super(DistillableViT, self).__init__(*args, **kwargs)
self.args = args
self.kwargs = kwargs
self.dim = kwargs['dim']
self.num_classes = kwargs['num_classes']
def to_vit(self):
v = ViT(*self.args, **self.kwargs)
v.load_state_dict(self.state_dict())
return v
def _attend(self, x, mask):
x = self.dropout(x)
x = self.transformer(x, mask)
return x
class DistillableEfficientViT(DistillMixin, EfficientViT):
def __init__(self, *args, **kwargs):
super(DistillableEfficientViT, self).__init__(*args, **kwargs)
self.args = args
self.kwargs = kwargs
self.dim = kwargs['dim']
self.num_classes = kwargs['num_classes']
def to_vit(self):
v = EfficientViT(*self.args, **self.kwargs)
v.load_state_dict(self.state_dict())
return v
def _attend(self, x, mask):
return self.transformer(x)
# knowledge distillation wrapper
class DistillWrapper(nn.Module):
def __init__(
@@ -52,7 +90,8 @@ class DistillWrapper(nn.Module):
alpha = 0.5
):
super().__init__()
assert isinstance(student, DistillableViT), 'student must be a vision transformer'
assert (isinstance(student, (DistillableViT, DistillableEfficientViT))) , 'student must be a vision transformer'
self.teacher = teacher
self.student = student
@@ -68,11 +107,14 @@ class DistillWrapper(nn.Module):
nn.Linear(dim, num_classes)
)
def forward(self, img, labels, temperature = None, **kwargs):
b, *_, alpha = *img.shape, self.alpha
def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
b, *_ = img.shape
alpha = alpha if exists(alpha) else self.alpha
T = temperature if exists(temperature) else self.temperature
teacher_logits = self.teacher(img)
with torch.no_grad():
teacher_logits = self.teacher(img)
student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
distill_logits = self.distill_mlp(distill_tokens)
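The hunks above show the wrapper's inputs (per-call `temperature` and `alpha` overrides, teacher logits computed under `torch.no_grad()`) but not the loss computation itself. For orientation, a distillation wrapper of this kind typically mixes a hard cross-entropy term with a temperature-scaled soft term against the teacher; a minimal sketch of that standard objective follows, which is not necessarily this repository's exact weighting or reduction:
```python
import torch.nn.functional as F

def distillation_loss(student_logits, distill_logits, teacher_logits, labels, T, alpha):
    # hard term: ordinary cross entropy between the student's class head and the labels
    hard = F.cross_entropy(student_logits, labels)

    # soft term: KL divergence between temperature-softened distill head and teacher
    soft = F.kl_div(
        F.log_softmax(distill_logits / T, dim = -1),
        F.softmax(teacher_logits / T, dim = -1),
        reduction = 'batchmean'
    ) * (T ** 2)   # rescale so gradient magnitude stays comparable across temperatures

    # alpha interpolates between the supervised and distillation signals
    return hard * (1 - alpha) + soft * alpha
```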

vit_pytorch/t2t.py (new file, 72 lines)

@@ -0,0 +1,72 @@
import math
import torch
from torch import nn
from vit_pytorch.vit_pytorch import Transformer
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# classes
class RearrangeImage(nn.Module):
def forward(self, x):
return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1])))
# main class
class T2TViT(nn.Module):
def __init__(
self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
layers = []
layer_dim = channels
for i, (kernel_size, stride) in enumerate(t2t_layers):
layer_dim *= kernel_size ** 2
is_first = i == 0
layers.extend([
RearrangeImage() if not is_first else nn.Identity(),
nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
Rearrange('b c n -> b n c'),
Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout),
])
layers.append(nn.Linear(layer_dim, dim))
self.to_patch_embedding = nn.Sequential(*layers)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)

vit_pytorch/vit_pytorch.py

@@ -1,7 +1,10 @@
import math
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from einops.layers.torch import Rearrange
MIN_NUM_PATCHES = 16
@@ -38,7 +41,7 @@ class Attention(nn.Module):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim ** -0.5
self.scale = dim_head ** -0.5
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
@@ -51,25 +54,25 @@ class Attention(nn.Module):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
mask_value = -torch.finfo(dots.dtype).max
if mask is not None:
mask = F.pad(mask.flatten(1), (1, 0), value = True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')
dots.masked_fill_(~mask, mask_value)
del mask
attn = dots.softmax(dim=-1)
out = torch.einsum('bhij,bhjd->bhid', attn, v)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
@@ -92,10 +95,12 @@ class ViT(nn.Module):
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.patch_size = patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
@@ -110,10 +115,7 @@ class ViT(nn.Module):
)
def forward(self, img, mask = None):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
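Two changes in this file alter behaviour rather than style: the attention logits are now scaled by `dim_head ** -0.5` (each head's dot products run over `dim_head` features, so that is the appropriate 1/sqrt(d) factor), and the boolean mask is expanded to shape `(batch, 1, i, j)` before it meets the `(batch, heads, i, j)` attention matrix. A small sketch of what the mask change fixes, using illustrative sizes:
```python
import torch
import torch.nn.functional as F
from einops import rearrange

b, heads, n = 2, 8, 5                                  # batch, heads, patches (illustrative)
mask = torch.ones(b, n).bool()
mask = F.pad(mask.flatten(1), (1, 0), value = True)    # prepend a slot for the cls token -> (b, n + 1)
dots = torch.randn(b, heads, n + 1, n + 1)             # attention logits

# old: mask[:, None, :] * mask[:, :, None] has shape (b, i, j); broadcasting it against
# (b, h, i, j) aligns the batch dimension with the heads dimension, which either raises a
# shape error or, when b equals h, silently pairs batch entries with attention heads

# new: explicit (b, 1, i, j) outer product, which broadcasts over heads unambiguously
mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')
dots.masked_fill_(~mask, -torch.finfo(dots.dtype).max)
```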