add knowledge distillation with distillation tokens, in light of new finding from facebook ai
README.md | 45
@@ -6,8 +6,6 @@ Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Tran
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>
Update: A <a href="https://arxiv.org/abs/2012.12877">new paper</a> proposes using distillation tokens for distilling conv nets into efficient image transformers. I'll offer a way to toggle the use of distillation tokens in the coming days.
## Install
```bash
@@ -62,6 +60,49 @@ Dropout rate.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling
## Distillation
A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that using a distillation token for distilling knowledge from convolutional nets to vision transformers can yield small and efficient vision transformers. This repository offers the means to do distillation easily.
ex. distilling from ResNet-50 (or any teacher) into a vision transformer
```python
import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    pool = 'mean'
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,    # temperature of distillation
    alpha = 0.5         # trade-off between the main loss and the distillation loss
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()
```
The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back into `ViT` after you have completed distillation training.
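For example, since only the forward pass differs, the student's state dict should line up one-to-one with a plain `ViT` built with the same hyperparameters. A minimal sketch (assuming the `v` and `img` from the example above, after distillation training):

```python
from vit_pytorch.vit_pytorch import ViT

# a plain ViT with the same hyperparameters as the distilled student above
vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    pool = 'mean'
)

# parameter names match, since DistillableViT only overrides the forward pass
vit.load_state_dict(v.state_dict())

preds = vit(img) # (2, 1000)
```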
## Research Ideas
### Self Supervised Training
setup.py | 2
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'vit-pytorch',
  packages = find_packages(exclude=['examples']),
-  version = '0.5.1',
+  version = '0.6.0',
  license='MIT',
  description = 'Vision Transformer (ViT) - Pytorch',
  author = 'Phil Wang',
vit_pytorch/distill.py | 88 (new file)
@@ -0,0 +1,88 @@
import torch
import torch.nn.functional as F
from torch import nn
from vit_pytorch.vit_pytorch import ViT

from einops import rearrange, repeat

# helpers

def exists(val):
    return val is not None

# classes

class DistillableViT(ViT):
    def __init__(self, *args, **kwargs):
        super(DistillableViT, self).__init__(*args, **kwargs)
        self.dim = kwargs['dim']
        self.num_classes = kwargs['num_classes']

    def forward(self, img, distill_token, mask = None):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]

        # append the distillation token to the end of the sequence
        distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
        x = torch.cat((x, distill_tokens), dim = 1)

        x = self.dropout(x)

        x = self.transformer(x, mask)

        # split the distillation token back out from the rest of the sequence
        x, distill_tokens = x[:, :-1], x[:, -1]
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x), distill_tokens

class DistillWrapper(nn.Module):
    def __init__(
        self,
        *,
        teacher,
        student,
        temperature = 1.,
        alpha = 0.5
    ):
        super().__init__()
        assert isinstance(student, DistillableViT), 'student must be a vision transformer'

        self.teacher = teacher
        self.student = student

        dim = student.dim
        num_classes = student.num_classes
        self.temperature = temperature
        self.alpha = alpha

        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))

        self.distill_mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, labels, temperature = None, **kwargs):
        b, *_, alpha = *img.shape, self.alpha
        T = temperature if exists(temperature) else self.temperature

        teacher_logits = self.teacher(img)
        student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
        distill_logits = self.distill_mlp(distill_tokens)

        # hard-label loss on the student's class token output
        loss = F.cross_entropy(student_logits, labels)

        # soft distillation loss: temperature-scaled KL divergence between
        # the distillation head and the (detached) teacher logits
        distill_loss = F.kl_div(
            F.log_softmax(distill_logits / T, dim = -1),
            F.softmax(teacher_logits / T, dim = -1).detach(),
            reduction = 'batchmean')

        distill_loss *= T ** 2

        # trade off between the hard-label loss and the distillation loss
        return loss * alpha + distill_loss * (1 - alpha)
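Note that `DistillableViT.forward` returns a tuple of the class logits and the final distillation token embedding, whereas `ViT` returns logits only. A hypothetical sketch of calling it directly, outside of `DistillWrapper` (the `model` and `distill_token` names below are illustrative, and the remaining `ViT` constructor defaults are assumed):

```python
import torch
from vit_pytorch.distill import DistillableViT

model = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    pool = 'cls'
)

# a stand-in distillation token; DistillWrapper normally owns this parameter
distill_token = torch.randn(1, 1, 1024)

logits, distill_embed = model(torch.randn(2, 3, 256, 256), distill_token = distill_token)
# logits: (2, 1000), distill_embed: (2, 1024)
```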