Compare commits


3 Commits
0.0.2 ... 0.0.3

Author     SHA1        Message                                             Date
Phil Wang  8fb261ca66  fix a bug and add suggestion for BYOL pre-training  2020-10-04 14:55:29 -07:00
Phil Wang  112ba5c476  update with link to Yannics video                   2020-10-04 13:53:47 -07:00
Phil Wang  f899226d4f  add diagram                                         2020-10-04 12:47:08 -07:00
4 changed files with 61 additions and 4 deletions

README.md

@@ -1,6 +1,8 @@
+<img src="./vit.png" width="500px"></img>
+
 ## Vision Transformer - Pytorch
 
-Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. There's really not much to code here, but may as well lay out all the code so we expedite the attention revolution and get everyone on the same page.
+Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
 
 ## Install
@@ -28,6 +30,56 @@ img = torch.randn(1, 3, 256, 256)
 preds = v(img) # (1, 1000)
 ```
+
+## Suggestion
+
+You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
+
+(1)
+```bash
+$ pip install byol-pytorch
+```
+
+(2)
+```python
+import torch
+from vit_pytorch import ViT
+from byol_pytorch import BYOL
+
+model = ViT(
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    dim = 1024,
+    depth = 6,
+    heads = 8,
+    mlp_dim = 2048
+)
+
+learner = BYOL(
+    model,
+    image_size = 256,
+    hidden_layer = 'to_cls_token'
+)
+
+opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
+
+def sample_unlabelled_images():
+    return torch.randn(20, 3, 256, 256)
+
+for _ in range(100):
+    images = sample_unlabelled_images()
+    loss = learner(images)
+    opt.zero_grad()
+    loss.backward()
+    opt.step()
+    learner.update_moving_average() # update moving average of target encoder
+
+# save your improved network
+torch.save(model.state_dict(), './pretrained-net.pt')
+```
+
+A pytorch-lightning script is ready for you to use at the repository link above.
 ## Citations
 ```bibtex

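A note on the `hidden_layer = 'to_cls_token'` argument in the added example: BYOL needs to know which module's output to treat as the image embedding, and this changeset adds a `to_cls_token = nn.Identity()` module to ViT (see the final diff below) to serve as a stable, named hook point. Below is a minimal sketch of the general mechanism, capturing a named submodule's output with a forward hook; it illustrates the idea and is not byol-pytorch's actual internals:

```python
import torch
from vit_pytorch import ViT

model = ViT(image_size = 256, patch_size = 32, num_classes = 1000,
            dim = 1024, depth = 6, heads = 8, mlp_dim = 2048)

captured = {}
layer = dict(model.named_modules())['to_cls_token']  # the nn.Identity hook point

def save_output(module, inputs, output):
    captured['embedding'] = output  # (batch, dim) class-token embedding

handle = layer.register_forward_hook(save_output)
model(torch.randn(2, 3, 256, 256))  # a forward pass fires the hook
print(captured['embedding'].shape)  # torch.Size([2, 1024])
handle.remove()
```

Since `nn.Identity` passes its input through unchanged, adding it does not alter ViT's output; it only gives external code a well-defined place to attach.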
setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(),
-  version = '0.0.2',
+  version = '0.0.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

vit.png: new binary file, 137 KiB (binary content not shown)

vit_pytorch/vit_pytorch.py

@@ -77,6 +77,8 @@ class ViT(nn.Module):
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
         self.transformer = Transformer(dim, depth, heads, mlp_dim)
+        self.to_cls_token = nn.Identity()
+
         self.mlp_head = nn.Sequential(
             nn.Linear(dim, mlp_dim),
             nn.GELU(),
@@ -88,8 +90,11 @@ class ViT(nn.Module):
         x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
         x = self.patch_to_embedding(x)
-        x = torch.cat((self.cls_token, x), dim=1)
+        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
         x += self.pos_embedding
         x = self.transformer(x)
-        return self.mlp_head(x[:, 0])
+        x = self.to_cls_token(x[:, 0])
+        return self.mlp_head(x)
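For context on the "fix a bug" part of commit 8fb261ca66: the old forward pass concatenated the learned class token, of shape `(1, 1, dim)`, directly onto a batch of patch embeddings, which only works for batch size 1, since `torch.cat` requires all non-concatenated dimensions to match. A minimal standalone sketch of the fixed behavior, with illustrative shapes:

```python
import torch

batch, num_patches, dim = 4, 64, 1024  # e.g. a 256px image in 32px patches

cls_token = torch.randn(1, 1, dim)              # one learned token shared by all images
patches = torch.randn(batch, num_patches, dim)  # patch embeddings for the batch

# the old code, torch.cat((cls_token, patches), dim=1), raises a size-mismatch
# error here because the batch dimensions (1 vs 4) disagree

cls_tokens = cls_token.expand(batch, -1, -1)    # (4, 1, 1024) view, no data copied
x = torch.cat((cls_tokens, patches), dim=1)     # (4, 65, 1024): class token + 64 patches
print(x.shape)                                  # torch.Size([4, 65, 1024])
```

`expand` creates a broadcasted view rather than copying the token, so every image in the batch shares the same learned parameter.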