Compare commits

..

10 Commits
0.2.7 ... 0.6.0

Author SHA1 Message Date
Phil Wang
aa9ed249a3 add knowledge distillation with distillation tokens, in light of new finding from facebook ai 2020-12-24 10:39:15 -08:00
Phil Wang
ea0924ec96 update readme 2020-12-23 19:06:48 -08:00
Phil Wang
59787a6b7e allow for mean pool with efficient version too 2020-12-23 18:15:40 -08:00
Phil Wang
24339644ca offer a way to use mean pooling of last layer 2020-12-23 17:23:58 -08:00
Phil Wang
b786029e18 fix the dimension per head to be independent of dim and heads, to make sure users do not have it be too small to learn anything 2020-12-17 07:43:52 -08:00
Phil Wang
9624181940 simplify mlp head 2020-12-07 14:31:50 -08:00
Phil Wang
a656a213e6 update diagram 2020-12-04 12:26:28 -08:00
Phil Wang
f1deb5fb7e Merge pull request #31 from minhlong94/main
Update README and documentation
2020-11-21 08:05:38 -08:00
Long M. Lưu
3f50dd72cf Update README.md 2020-11-21 18:37:03 +07:00
Long M. Lưu
ee5e4e9929 Update vit_pytorch.py 2020-11-21 18:23:04 +07:00
7 changed files with 184 additions and 25 deletions

View File

@@ -1,4 +1,4 @@
<img src="./vit.png" width="500px"></img>
<img src="./vit.gif" width="500px"></img>
## Vision Transformer - Pytorch
@@ -24,7 +24,7 @@ v = ViT(
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
@@ -36,6 +36,73 @@ mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to at
preds = v(img, mask = mask) # (1, 1000)
```
## Parameters
- `image_size`: int.
Image size.
- `patch_size`: int.
Number of patches. `image_size` must be divisible by `patch_size`.
The number of patches is: ` n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
- `num_classes`: int.
Number of classes to classify.
- `dim`: int.
Last dimension of output tensor after linear transformation `nn.Linear(..., dim)`.
- `depth`: int.
Number of Transformer blocks.
- `heads`: int.
Number of heads in Multi-head Attention layer.
- `mlp_dim`: int.
Dimension of the MLP (FeedForward) layer.
- `channels`: int, default `3`.
Number of image's channels.
- `dropout`: float between `[0, 1]`, default `0.`.
Dropout rate.
- `emb_dropout`: float between `[0, 1]`, default `0`.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling
## Distillation
A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that use of a distillation token for distilling knowledge from convolutional nets to vision transformer can yield small and efficient vision transformers. This repository offers the means to do distillation easily.
ex. distilling from Resnet50 (or any teacher) to a vision transformer
```python
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper
teacher = resnet50(pretrained = True)
v = DistillableViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1,
pool = 'mean'
)
distiller = DistillWrapper(
student = v,
teacher = teacher,
temperature = 3, # temperature of distillation
alpha = 0.5 # trade between main loss and distillation loss
)
img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))
loss = distiller(img, labels)
loss.backward()
```
The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back to `ViT` after you have completed distillation training.
## Research Ideas
### Self Supervised Training
@@ -66,7 +133,7 @@ model = ViT(
learner = BYOL(
model,
image_size = 256,
hidden_layer = 'to_cls_token'
hidden_layer = 'to_latent'
)
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.2.7',
version = '0.6.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

BIN
vit.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.8 MiB

BIN
vit.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 137 KiB

88
vit_pytorch/distill.py Normal file
View File

@@ -0,0 +1,88 @@
import torch
import torch.nn.functional as F
from torch import nn
from vit_pytorch.vit_pytorch import ViT
from einops import rearrange, repeat
# helpers
def exists(val):
return val is not None
# classes
class DistillableViT(ViT):
def __init__(self, *args, **kwargs):
super(DistillableViT, self).__init__(*args, **kwargs)
self.dim = kwargs['dim']
self.num_classes = kwargs['num_classes']
def forward(self, img, distill_token, mask = None):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
x = self.dropout(x)
x = self.transformer(x, mask)
x, distill_tokens = x[:, :-1], x[:, -1]
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x), distill_tokens
class DistillWrapper(nn.Module):
def __init__(
self,
*,
teacher,
student,
temperature = 1.,
alpha = 0.5
):
super().__init__()
assert isinstance(student, DistillableViT), 'student must be a vision transformer'
self.teacher = teacher
self.student = student
dim = student.dim
num_classes = student.num_classes
self.temperature = temperature
self.alpha = alpha
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
self.distill_mlp = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img, labels, temperature = None, **kwargs):
b, *_, alpha = *img.shape, self.alpha
T = temperature if exists(temperature) else self.temperature
teacher_logits = self.teacher(img)
student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
distill_logits = self.distill_mlp(distill_tokens)
loss = F.cross_entropy(student_logits, labels)
distill_loss = F.kl_div(
F.log_softmax(distill_logits / T, dim = -1),
F.softmax(teacher_logits / T, dim = -1).detach(),
reduction = 'batchmean')
distill_loss *= T ** 2
return loss * alpha + distill_loss * (1 - alpha)

View File

@@ -3,9 +3,10 @@ from einops import rearrange, repeat
from torch import nn
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, channels = 3):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
@@ -16,13 +17,12 @@ class ViT(nn.Module):
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = transformer
self.to_cls_token = nn.Identity()
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, num_classes)
nn.Linear(dim, num_classes)
)
def forward(self, img):
@@ -37,5 +37,7 @@ class ViT(nn.Module):
x += self.pos_embedding[:, :(n + 1)]
x = self.transformer(x)
x = self.to_cls_token(x[:, 0])
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)

View File

@@ -34,14 +34,15 @@ class FeedForward(nn.Module):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dropout = 0.):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim ** -0.5
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(dim, dim),
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
@@ -68,12 +69,12 @@ class Attention(nn.Module):
return out
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, mlp_dim, dropout):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x, mask = None):
@@ -83,12 +84,13 @@ class Transformer(nn.Module):
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.patch_size = patch_size
@@ -97,16 +99,14 @@ class ViT(nn.Module):
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.to_cls_token = nn.Identity()
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, num_classes)
nn.Linear(dim, num_classes)
)
def forward(self, img, mask = None):
@@ -123,5 +123,7 @@ class ViT(nn.Module):
x = self.transformer(x, mask)
x = self.to_cls_token(x[:, 0])
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)