Compare commits


20 Commits
0.5.1 ... 0.6.7

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Phil Wang | 85314cf0b6 | patch for scaling factor, thanks to @urkax | 2021-01-21 09:39:42 -08:00 |
| Phil Wang | 5db8d9deed | update readme about non-square images | 2021-01-12 06:55:45 -08:00 |
| Phil Wang | e8ca6038c9 | allow for DistillableViT to still run predictions | 2021-01-11 10:49:14 -08:00 |
| Phil Wang | 1106a2ba88 | link to official repo | 2021-01-08 08:23:50 -08:00 |
| Phil Wang | f95fa59422 | link to resources for vision people | 2021-01-04 10:10:54 -08:00 |
| Phil Wang | be1712ebe2 | add quote | 2020-12-28 10:22:59 -08:00 |
| Phil Wang | 1a76944124 | update readme | 2020-12-27 19:10:38 -08:00 |
| Phil Wang | 2263b7396f | allow distillable efficient vit to restore efficient vit as well | 2020-12-25 19:31:25 -08:00 |
| Phil Wang | 74074e2b6c | offer easy way to turn DistillableViT to ViT at the end of training | 2020-12-25 11:16:52 -08:00 |
| Phil Wang | 0c68688d61 | bump for release | 2020-12-25 09:30:48 -08:00 |
| Phil Wang | 5918f301a2 | cleanup | 2020-12-25 09:30:38 -08:00 |
| Phil Wang | 4a6469eecc | Merge pull request #51 from umbertov/main: Add class for distillation with efficient attention | 2020-12-25 09:21:17 -08:00 |
| Umberto Valleriani | 5a225c8e3f | Add class for distillation with efficient attention (DistillableEfficientViT does the same as DistillableViT, except it may accept a custom transformer encoder, possibly implementing an efficient attention mechanism) | 2020-12-25 13:46:29 +01:00 |
| Phil Wang | e0007bd801 | add distill diagram | 2020-12-24 11:34:15 -08:00 |
| Phil Wang | db98ed7a8e | allow for overriding alpha as well on forward in distillation wrapper | 2020-12-24 11:18:36 -08:00 |
| Phil Wang | dc4b3327ce | no grad for teacher in distillation | 2020-12-24 11:11:58 -08:00 |
| Phil Wang | aa8f0a7bf3 | Update README.md | 2020-12-24 10:59:03 -08:00 |
| Phil Wang | 34e6284f95 | Update README.md | 2020-12-24 10:58:41 -08:00 |
| Phil Wang | aa9ed249a3 | add knowledge distillation with distillation tokens, in light of new finding from facebook ai | 2020-12-24 10:39:15 -08:00 |
| Phil Wang | ea0924ec96 | update readme | 2020-12-23 19:06:48 -08:00 |
5 changed files with 248 additions and 4 deletions

118
README.md

@@ -4,7 +4,9 @@
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>.
The official Jax repository is <a href="https://github.com/google-research/vision_transformer">here</a>.
## Install
@@ -38,7 +40,7 @@ preds = v(img, mask = mask) # (1, 1000)
## Parameters
- `image_size`: int.
Image size.
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height.
- `patch_size`: int.
Size of patches. `image_size` must be divisible by `patch_size`.
The number of patches is `n = (image_size // patch_size) ** 2`, and `n` **must be greater than 16** (see the quick check after this list).
@@ -60,6 +62,61 @@ Dropout rate.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling
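For instance, a quick sanity check of the patch-count rule above; the numbers here are illustrative, not from the repository:

```python
# hypothetical values, just to illustrate the constraint
image_size, patch_size = 256, 32
assert image_size % patch_size == 0   # image must divide evenly into patches
n = (image_size // patch_size) ** 2   # -> 64 patches
assert n > 16
```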
## Distillation
<img src="./distill.png" width="300px"></img>
A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that using a distillation token to distill knowledge from convolutional nets into vision transformers can yield small and efficient vision transformers. This repository offers the means to do distillation easily.
ex. distilling from ResNet-50 (or any teacher) into a vision transformer
```python
import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,    # temperature of distillation
    alpha = 0.5         # trade-off between main loss and distillation loss
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after lots of training above ...

pred = v(img) # (2, 1000)
```
The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back into `ViT` after you have completed distillation training.
You can also use the handy `.to_vit` method on the `DistillableViT` instance to get back a `ViT` instance.
```python
v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```
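Since `DistillableViT` and `ViT` share the same parameter names, a plain `state_dict` round-trip should also work. A minimal sketch, assuming the same constructor arguments as the example above and an illustrative file path:

```python
import torch
from vit_pytorch import ViT

# hypothetical round-trip: persist the distilled weights, then load them into a plain ViT
torch.save(v.state_dict(), './distilled-vit.pt')

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
model.load_state_dict(torch.load('./distilled-vit.pt'))
```

This is the same mechanism `.to_vit` uses internally (`v.load_state_dict(self.state_dict())` in `distill.py`).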
## Research Ideas
### Self Supervised Training
@@ -149,6 +206,50 @@ v(img) # (1, 1000)
Other sparse attention frameworks I would highly recommend are <a href="https://github.com/lucidrains/routing-transformer">Routing Transformer</a> and <a href="https://github.com/lucidrains/sinkhorn-transformer">Sinkhorn Transformer</a>.
### Combining with other Transformer improvements
This paper purposely used the most vanilla of attention networks to make a statement. If you would like to use some of the latest improvements for attention nets, please use the `Encoder` from <a href="https://github.com/lucidrains/x-transformers">this repository</a>.
ex.
```bash
$ pip install x-transformers
```
```python
import torch
from vit_pytorch.efficient import ViT
from x_transformers import Encoder

v = ViT(
    dim = 512,
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    transformer = Encoder(
        dim = 512,              # set to be the same as the wrapper
        depth = 12,
        heads = 8,
        ff_glu = True,          # ex. feed forward GLU variant https://arxiv.org/abs/2002.05202
        residual_attn = True    # ex. residual attention https://arxiv.org/abs/2012.11747
    )
)

img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)
```
## Resources
Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
1. <a href="http://jalammar.github.io/illustrated-transformer/">Illustrated Transformer</a> - Jay Alammar
2. <a href="http://peterbloem.nl/blog/transformers">Transformers from Scratch</a> - Peter Bloem
3. <a href="https://nlp.seas.harvard.edu/2018/04/03/attention.html">The Annotated Transformer</a> - Harvard NLP
## Citations
```bibtex
@@ -162,6 +263,17 @@ Other sparse attention frameworks I would highly recommend is <a href="https://g
}
```
```bibtex
@misc{touvron2020training,
title = {Training data-efficient image transformers & distillation through attention},
author = {Hugo Touvron and Matthieu Cord and Matthijs Douze and Francisco Massa and Alexandre Sablayrolles and Hervé Jégou},
year = {2020},
eprint = {2012.12877},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
@@ -172,3 +284,5 @@ Other sparse attention frameworks I would highly recommend is <a href="https://g
primaryClass = {cs.CL}
}
```
*I visualise a time when we will be to robots what dogs are to humans, and I'm rooting for the machines.* — Claude Shannon

BIN
distill.png Normal file
Binary file not shown (49 KiB).

setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.5.1',
version = '0.6.7',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

130
vit_pytorch/distill.py Normal file

@@ -0,0 +1,130 @@
import torch
import torch.nn.functional as F
from torch import nn

from vit_pytorch.vit_pytorch import ViT
from vit_pytorch.efficient import ViT as EfficientViT

from einops import rearrange, repeat

# helpers

def exists(val):
    return val is not None

# classes

class DistillMixin:
    def forward(self, img, distill_token = None, mask = None):
        p, distilling = self.patch_size, exists(distill_token)

        # split the image into patches and project them to the embedding dimension
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        # prepend the class token and add positional embeddings
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]

        if distilling:
            # append the distillation token to the end of the sequence
            distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
            x = torch.cat((x, distill_tokens), dim = 1)

        x = self._attend(x, mask)

        if distilling:
            # split the distillation token back off the sequence
            x, distill_tokens = x[:, :-1], x[:, -1]

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        out = self.mlp_head(x)

        if distilling:
            return out, distill_tokens

        return out

class DistillableViT(DistillMixin, ViT):
    def __init__(self, *args, **kwargs):
        super(DistillableViT, self).__init__(*args, **kwargs)
        self.args = args
        self.kwargs = kwargs
        self.dim = kwargs['dim']
        self.num_classes = kwargs['num_classes']

    def to_vit(self):
        # recreate a plain ViT and copy the trained weights over
        v = ViT(*self.args, **self.kwargs)
        v.load_state_dict(self.state_dict())
        return v

    def _attend(self, x, mask):
        x = self.dropout(x)
        x = self.transformer(x, mask)
        return x

class DistillableEfficientViT(DistillMixin, EfficientViT):
    def __init__(self, *args, **kwargs):
        super(DistillableEfficientViT, self).__init__(*args, **kwargs)
        self.args = args
        self.kwargs = kwargs
        self.dim = kwargs['dim']
        self.num_classes = kwargs['num_classes']

    def to_vit(self):
        v = EfficientViT(*self.args, **self.kwargs)
        v.load_state_dict(self.state_dict())
        return v

    def _attend(self, x, mask):
        return self.transformer(x)

# knowledge distillation wrapper

class DistillWrapper(nn.Module):
    def __init__(
        self,
        *,
        teacher,
        student,
        temperature = 1.,
        alpha = 0.5
    ):
        super().__init__()
        assert isinstance(student, (DistillableViT, DistillableEfficientViT)), 'student must be a vision transformer'

        self.teacher = teacher
        self.student = student

        dim = student.dim
        num_classes = student.num_classes
        self.temperature = temperature
        self.alpha = alpha

        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))

        self.distill_mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
        b, *_ = img.shape
        alpha = alpha if exists(alpha) else self.alpha
        T = temperature if exists(temperature) else self.temperature

        # teacher predictions are soft targets only, so no gradients are needed
        with torch.no_grad():
            teacher_logits = self.teacher(img)

        student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
        distill_logits = self.distill_mlp(distill_tokens)

        # standard cross entropy against the true labels
        loss = F.cross_entropy(student_logits, labels)

        # soft distillation loss against the temperature-scaled teacher distribution
        distill_loss = F.kl_div(
            F.log_softmax(distill_logits / T, dim = -1),
            F.softmax(teacher_logits / T, dim = -1).detach(),
            reduction = 'batchmean')
        distill_loss *= T ** 2

        return loss * alpha + distill_loss * (1 - alpha)
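For reference, the objective returned above is `alpha * CE(student_logits, labels) + (1 - alpha) * T**2 * KL(softmax(distill_logits / T) || softmax(teacher_logits / T))`. Because `forward` accepts optional `temperature` and `alpha` arguments, the wrapper defaults can be overridden per call; a minimal sketch (the values are illustrative):

```python
# override the wrapper defaults for a single training step
loss = distiller(img, labels, temperature = 5., alpha = 0.3)
loss.backward()
```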

vit_pytorch/vit_pytorch.py

@@ -38,7 +38,7 @@ class Attention(nn.Module):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim ** -0.5
self.scale = dim_head ** -0.5
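# note: attention logits are dot products over the per-head dimension dim_head,
# so the softmax temperature should be dim_head ** -0.5, not dim ** -0.5 (the patch credited to @urkax above)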
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(