Mirror of https://github.com/lucidrains/vit-pytorch.git, synced 2025-12-30 08:02:29 +00:00
fix a bug and add suggestion for BYOL pre-training
README.md (50 changed lines)
@@ -30,6 +30,56 @@ img = torch.randn(1, 3, 256, 256)
 preds = v(img) # (1, 1000)
 ```
 
+## Suggestion
+
+You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
+
+(1)
+```bash
+$ pip install byol-pytorch
+```
+
+(2)
+```python
+import torch
+from vit_pytorch import ViT
+from byol_pytorch import BYOL
+
+model = ViT(
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    dim = 1024,
+    depth = 6,
+    heads = 8,
+    mlp_dim = 2048
+)
+
+learner = BYOL(
+    model,
+    image_size = 256,
+    hidden_layer = 'to_cls_token'
+)
+
+opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
+
+def sample_unlabelled_images():
+    return torch.randn(20, 3, 256, 256)
+
+for _ in range(100):
+    images = sample_unlabelled_images()
+    loss = learner(images)
+    opt.zero_grad()
+    loss.backward()
+    opt.step()
+    learner.update_moving_average() # update moving average of target encoder
+
+# save your improved network
+torch.save(model.state_dict(), './pretrained-net.pt')
+```
+
+A pytorch-lightning script is ready for you to use at the repository link above.
+
 ## Citations
 
 ```bibtex
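Not part of this commit, but as a usage note: the BYOL loop above ends with `torch.save(model.state_dict(), './pretrained-net.pt')`, and the natural next step is to load those weights back into a ViT for supervised fine-tuning. A minimal sketch, assuming the same ViT hyperparameters as in the README snippet:

```python
# Hypothetical follow-up (not in this commit): reload the BYOL pre-trained
# weights into a fresh ViT and run a forward pass, ready for fine-tuning.
import torch
from vit_pytorch import ViT

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

# load the state dict written at the end of the BYOL training loop
model.load_state_dict(torch.load('./pretrained-net.pt'))

img = torch.randn(1, 3, 256, 256)
preds = model(img) # (1, 1000), now fine-tune on labelled data as usual
```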
setup.py (2 changed lines)
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(),
-  version = '0.0.2',
+  version = '0.0.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
vit_pytorch/vit_pytorch.py
@@ -77,6 +77,8 @@ class ViT(nn.Module):
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
         self.transformer = Transformer(dim, depth, heads, mlp_dim)
 
+        self.to_cls_token = nn.Identity()
+
         self.mlp_head = nn.Sequential(
             nn.Linear(dim, mlp_dim),
             nn.GELU(),
@@ -88,8 +90,11 @@ class ViT(nn.Module):
 
         x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
         x = self.patch_to_embedding(x)
-        x = torch.cat((self.cls_token, x), dim=1)
+
+        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
         x += self.pos_embedding
         x = self.transformer(x)
 
-        return self.mlp_head(x[:, 0])
+        x = self.to_cls_token(x[:, 0])
+        return self.mlp_head(x)
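The two hunks above are related: `self.cls_token` is a learnable parameter of shape (1, 1, dim), and concatenating it directly onto patch embeddings of shape (batch, num_patches, dim) only works for a batch size of 1, which is the bug this commit fixes by expanding the token across the batch first. The new `self.to_cls_token = nn.Identity()` gives the pooled CLS representation a named module, which is what the README's `hidden_layer = 'to_cls_token'` argument to BYOL points at. A shapes-only sketch of the fix, with made-up sizes:

```python
# Illustrative only: why the cls token must be expanded before torch.cat.
import torch

batch, num_patches, dim = 4, 64, 1024
patch_embeddings = torch.randn(batch, num_patches, dim)   # (4, 64, 1024)
cls_token = torch.randn(1, 1, dim)                        # (1, 1, 1024), shared across the batch

# Before the fix: torch.cat((cls_token, patch_embeddings), dim=1) raises a
# RuntimeError for batch > 1, because all non-concatenated dims must match.

# After the fix: broadcast the token over the batch dimension, then concatenate.
cls_tokens = cls_token.expand(batch, -1, -1)              # (4, 1, 1024), no data copy
x = torch.cat((cls_tokens, patch_embeddings), dim=1)      # (4, 65, 1024)
print(x.shape)  # torch.Size([4, 65, 1024])
```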