mirror of https://github.com/lucidrains/vit-pytorch.git
synced 2026-05-01 13:00:18 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date |
| --- | --- | --- |
|  | 8fb261ca66 |  |
|  | 112ba5c476 |  |
|  | f899226d4f |  |
README.md (54 changes)
@@ -1,6 +1,8 @@
 <img src="./vit.png" width="500px"></img>
 
 ## Vision Transformer - Pytorch
 
-Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. There's really not much to code here, but may as well lay out all the code so we expedite the attention revolution and get everyone on the same page.
+Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
+
+
 ## Install
@@ -28,6 +30,56 @@ img = torch.randn(1, 3, 256, 256)
 preds = v(img) # (1, 1000)
 ```
 
+## Suggestion
+
+You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
+
+(1)
+```bash
+$ pip install byol-pytorch
+```
+
+(2)
+```python
+import torch
+from vit_pytorch import ViT
+from byol_pytorch import BYOL
+
+model = ViT(
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    dim = 1024,
+    depth = 6,
+    heads = 8,
+    mlp_dim = 2048
+)
+
+learner = BYOL(
+    model,
+    image_size = 256,
+    hidden_layer = 'to_cls_token'
+)
+
+opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
+
+def sample_unlabelled_images():
+    return torch.randn(20, 3, 256, 256)
+
+for _ in range(100):
+    images = sample_unlabelled_images()
+    loss = learner(images)
+    opt.zero_grad()
+    loss.backward()
+    opt.step()
+    learner.update_moving_average() # update moving average of target encoder
+
+# save your improved network
+torch.save(model.state_dict(), './pretrained-net.pt')
+```
+
+A pytorch-lightning script is ready for you to use at the repository link above.
+
 ## Citations
 
 ```bibtex
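Two details of the BYOL snippet above are worth unpacking.

First, `hidden_layer = 'to_cls_token'` names the submodule whose output BYOL uses as the image representation. The general mechanism for reading an intermediate activation off a named submodule is a PyTorch forward hook; here is a minimal sketch with a hypothetical `embed_with_hook` helper (illustrative only, not BYOL's actual internals):

```python
import torch
from vit_pytorch import ViT

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

# hypothetical helper: capture the output of a named submodule via a forward hook
def embed_with_hook(net, layer_name, images):
    captured = {}
    layer = dict(net.named_modules())[layer_name]  # e.g. the nn.Identity registered as 'to_cls_token'
    handle = layer.register_forward_hook(
        lambda module, inputs, output: captured.update(emb = output))
    net(images)       # ordinary forward pass; the hook fires when 'to_cls_token' runs
    handle.remove()
    return captured['emb']

emb = embed_with_hook(model, 'to_cls_token', torch.randn(2, 3, 256, 256))
print(emb.shape)      # torch.Size([2, 1024]), one class-token embedding per image
```

Second, `learner.update_moving_average()` exists because BYOL's target network is not trained by gradient descent; it trails the online network as an exponential moving average (EMA) of its weights. A generic EMA update of this kind looks as follows (the 0.99 decay is illustrative, not necessarily BYOL's default):

```python
import torch

@torch.no_grad()
def ema_update(target_net, online_net, decay = 0.99):
    # target = decay * target + (1 - decay) * online, parameter by parameter
    for t, o in zip(target_net.parameters(), online_net.parameters()):
        t.mul_(decay).add_(o, alpha = 1 - decay)
```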
setup.py (2 changes)
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(),
-  version = '0.0.2',
+  version = '0.0.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
@@ -77,6 +77,8 @@ class ViT(nn.Module):
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
         self.transformer = Transformer(dim, depth, heads, mlp_dim)
 
+        self.to_cls_token = nn.Identity()
+
         self.mlp_head = nn.Sequential(
             nn.Linear(dim, mlp_dim),
             nn.GELU(),
@@ -88,8 +90,11 @@ class ViT(nn.Module):
 
         x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
         x = self.patch_to_embedding(x)
-        x = torch.cat((self.cls_token, x), dim=1)
+
+        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
         x += self.pos_embedding
         x = self.transformer(x)
 
-        return self.mlp_head(x[:, 0])
+        x = self.to_cls_token(x[:, 0])
+        return self.mlp_head(x)
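Two notes on the last diff, for readers skimming it. The `nn.Identity` added in the first hunk computes nothing; it exists purely so the pooled class token flows through a module named `to_cls_token`, giving external tools such as BYOL a stable attachment point. The second hunk fixes a batch-size bug: `torch.cat` does not broadcast, so concatenating the `(1, 1, dim)` class-token parameter onto a batch of more than one image raises a size-mismatch error, whereas `expand` produces a batch-sized view of the same storage (no copy) that concatenates cleanly. A standalone before/after sketch of just that change:

```python
import torch

batch, num_patches, dim = 8, 64, 1024
cls_token = torch.randn(1, 1, dim)              # stands in for the learned nn.Parameter
patches = torch.randn(batch, num_patches, dim)  # patch embeddings for 8 images

# old behaviour: cat demands matching sizes on every non-concatenated dim, so this raises
try:
    torch.cat((cls_token, patches), dim=1)
except RuntimeError as err:
    print('fails for batch > 1:', err)

# new behaviour: expand makes a (batch, 1, dim) view, after which cat succeeds
cls_tokens = cls_token.expand(batch, -1, -1)
x = torch.cat((cls_tokens, patches), dim=1)
print(x.shape)                                  # torch.Size([8, 65, 1024])
```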