Compare commits

...

14 Commits
0.6.1 ... 0.6.5

Author SHA1 Message Date
Phil Wang
e8ca6038c9 allow for DistillableVit to still run predictions 2021-01-11 10:49:14 -08:00
Phil Wang
1106a2ba88 link to official repo 2021-01-08 08:23:50 -08:00
Phil Wang
f95fa59422 link to resources for vision people 2021-01-04 10:10:54 -08:00
Phil Wang
be1712ebe2 add quote 2020-12-28 10:22:59 -08:00
Phil Wang
1a76944124 update readme 2020-12-27 19:10:38 -08:00
Phil Wang
2263b7396f allow distillable efficient vit to restore efficient vit as well 2020-12-25 19:31:25 -08:00
Phil Wang
74074e2b6c offer easy way to turn DistillableViT to ViT at the end of training 2020-12-25 11:16:52 -08:00
Phil Wang
0c68688d61 bump for release 2020-12-25 09:30:48 -08:00
Phil Wang
5918f301a2 cleanup 2020-12-25 09:30:38 -08:00
Phil Wang
4a6469eecc Merge pull request #51 from umbertov/main
Add class for distillation with efficient attention
2020-12-25 09:21:17 -08:00
Umberto Valleriani
5a225c8e3f Add class for distillation with efficient attention
DistillableEfficientViT does the same as DistillableViT, except it
may accept a custom transformer encoder, possibly implementing an
efficient attention mechanism
2020-12-25 13:46:29 +01:00
Phil Wang
e0007bd801 add distill diagram 2020-12-24 11:34:15 -08:00
Phil Wang
db98ed7a8e allow for overriding alpha as well on forward in distillation wrapper 2020-12-24 11:18:36 -08:00
Phil Wang
dc4b3327ce no grad for teacher in distillation 2020-12-24 11:11:58 -08:00
4 changed files with 124 additions and 22 deletions

View File

@@ -4,7 +4,9 @@
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>.
The official Jax repository is <a href="https://github.com/google-research/vision_transformer">here</a>.
## Install
@@ -62,6 +64,8 @@ Embedding dropout rate.
## Distillation
<img src="./distill.png" width="300px"></img>
A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that use of a distillation token for distilling knowledge from convolutional nets to vision transformer can yield small and efficient vision transformers. This repository offers the means to do distillation easily.
ex. distilling from Resnet50 (or any teacher) to a vision transformer
@@ -83,8 +87,7 @@ v = DistillableViT(
heads = 8,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1,
pool = 'mean'
emb_dropout = 0.1
)
distiller = DistillWrapper(
@@ -99,10 +102,21 @@ labels = torch.randint(0, 1000, (2,))
loss = distiller(img, labels)
loss.backward()
# after lots of training above ...
pred = v(img) # (2, 1000)
```
The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back to `ViT` after you have completed distillation training.
You can also use the handy `.to_vit` method on the `DistillableViT` instance to get back a `ViT` instance.
```python
v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```
## Research Ideas
### Self Supervised Training
@@ -192,6 +206,50 @@ v(img) # (1, 1000)
Other sparse attention frameworks I would highly recommend is <a href="https://github.com/lucidrains/routing-transformer">Routing Transformer</a> or <a href="https://github.com/lucidrains/sinkhorn-transformer">Sinkhorn Transformer</a>
### Combining with other Transformer improvements
This paper purposely used the most vanilla of attention networks to make a statement. If you would like to use some of the latest improvements for attention nets, please use the `Encoder` from <a href="https://github.com/lucidrains/x-transformers">this repository</a>.
ex.
```bash
$ pip install x-transformers
```
```python
import torch
from vit_pytorch.efficient import ViT
from x_transformers import Encoder
v = ViT(
dim = 512,
image_size = 224,
patch_size = 16,
num_classes = 1000,
transformer = Encoder(
dim = 512, # set to be the same as the wrapper
depth = 12,
heads = 8,
ff_glu = True, # ex. feed forward GLU variant https://arxiv.org/abs/2002.05202
residual_attn = True # ex. residual attention https://arxiv.org/abs/2012.11747
)
)
img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)
```
## Resources
Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
1. <a href="http://jalammar.github.io/illustrated-transformer/">Illustrated Transformer</a> - Jay Alammar
2. <a href="http://peterbloem.nl/blog/transformers">Transformers from Scratch</a> - Peter Bloem
3. <a href="https://nlp.seas.harvard.edu/2018/04/03/attention.html">The Annotated Transformer</a> - Harvard NLP
## Citations
```bibtex
@@ -226,3 +284,5 @@ Other sparse attention frameworks I would highly recommend is <a href="https://g
primaryClass = {cs.CL}
}
```
*I visualise a time when we will be to robots what dogs are to humans, and Im rooting for the machines.* — Claude Shannon

BIN
distill.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.6.0',
version = '0.6.6',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from vit_pytorch.vit_pytorch import ViT
from vit_pytorch.efficient import ViT as EfficientViT
from einops import rearrange, repeat
@@ -12,14 +13,9 @@ def exists(val):
# classes
class DistillableViT(ViT):
def __init__(self, *args, **kwargs):
super(DistillableViT, self).__init__(*args, **kwargs)
self.dim = kwargs['dim']
self.num_classes = kwargs['num_classes']
def forward(self, img, distill_token, mask = None):
p = self.patch_size
class DistillMixin:
def forward(self, img, distill_token = None, mask = None):
p, distilling = self.patch_size, exists(distill_token)
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
@@ -29,18 +25,60 @@ class DistillableViT(ViT):
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
if distilling:
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
x = self.dropout(x)
x = self._attend(x, mask)
x = self.transformer(x, mask)
if distilling:
x, distill_tokens = x[:, :-1], x[:, -1]
x, distill_tokens = x[:, :-1], x[:, -1]
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x), distill_tokens
out = self.mlp_head(x)
if distilling:
return out, distill_tokens
return out
class DistillableViT(DistillMixin, ViT):
def __init__(self, *args, **kwargs):
super(DistillableViT, self).__init__(*args, **kwargs)
self.args = args
self.kwargs = kwargs
self.dim = kwargs['dim']
self.num_classes = kwargs['num_classes']
def to_vit(self):
v = ViT(*self.args, **self.kwargs)
v.load_state_dict(self.state_dict())
return v
def _attend(self, x, mask):
x = self.dropout(x)
x = self.transformer(x, mask)
return x
class DistillableEfficientViT(DistillMixin, EfficientViT):
def __init__(self, *args, **kwargs):
super(DistillableEfficientViT, self).__init__(*args, **kwargs)
self.args = args
self.kwargs = kwargs
self.dim = kwargs['dim']
self.num_classes = kwargs['num_classes']
def to_vit(self):
v = EfficientViT(*self.args, **self.kwargs)
v.load_state_dict(self.state_dict())
return v
def _attend(self, x, mask):
return self.transformer(x)
# knowledge distillation wrapper
class DistillWrapper(nn.Module):
def __init__(
@@ -52,7 +90,8 @@ class DistillWrapper(nn.Module):
alpha = 0.5
):
super().__init__()
assert isinstance(student, DistillableViT), 'student must be a vision transformer'
assert (isinstance(student, (DistillableViT, DistillableEfficientViT))) , 'student must be a vision transformer'
self.teacher = teacher
self.student = student
@@ -68,11 +107,14 @@ class DistillWrapper(nn.Module):
nn.Linear(dim, num_classes)
)
def forward(self, img, labels, temperature = None, **kwargs):
b, *_, alpha = *img.shape, self.alpha
def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
b, *_ = img.shape
alpha = alpha if exists(alpha) else self.alpha
T = temperature if exists(temperature) else self.temperature
teacher_logits = self.teacher(img)
with torch.no_grad():
teacher_logits = self.teacher(img)
student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
distill_logits = self.distill_mlp(distill_tokens)