Mirror of https://github.com/lucidrains/vit-pytorch.git, synced 2025-12-30 08:02:29 +00:00
fix a bug and add suggestion for BYOL pre-training
README.md | 50
@@ -30,6 +30,56 @@ img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
```

## Suggestion

You can train this with a near-SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, using the following code.

(1) Install the byol-pytorch package:

```bash
$ pip install byol-pytorch
```

(2) Wrap the ViT in a BYOL learner and run the self-supervised training loop:

```python
import torch
from vit_pytorch import ViT
from byol_pytorch import BYOL

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

learner = BYOL(
    model,
    image_size = 256,
    hidden_layer = 'to_cls_token'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')
```

A pytorch-lightning script is ready for you to use at the repository linked above.
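
After pre-training, the saved weights can be loaded back into a ViT constructed with the same hyperparameters for supervised fine-tuning. A minimal sketch, assuming the `./pretrained-net.pt` checkpoint saved above; the fine-tuning step and its placeholder data are illustrative, not from the repository:

```python
import torch
import torch.nn.functional as F
from vit_pytorch import ViT

# rebuild the network with the same hyperparameters used during BYOL pre-training
model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

# load the weights saved by the pre-training script above
model.load_state_dict(torch.load('./pretrained-net.pt'))

# one supervised fine-tuning step on placeholder images and labels
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
images = torch.randn(8, 3, 256, 256)
labels = torch.randint(0, 1000, (8,))

loss = F.cross_entropy(model(images), labels)
opt.zero_grad()
loss.backward()
opt.step()
```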

## Citations

```bibtex
```
setup.py | 2
```diff
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(),
-  version = '0.0.2',
+  version = '0.0.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
```
vit_pytorch/vit_pytorch.py

```diff
@@ -77,6 +77,8 @@ class ViT(nn.Module):
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
         self.transformer = Transformer(dim, depth, heads, mlp_dim)
 
+        self.to_cls_token = nn.Identity()
+
         self.mlp_head = nn.Sequential(
             nn.Linear(dim, mlp_dim),
             nn.GELU(),
```
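
The `to_cls_token = nn.Identity()` layer exists so the pooled cls-token representation passes through a named submodule, which BYOL's `hidden_layer = 'to_cls_token'` argument can then locate and capture. A minimal sketch of the underlying mechanism using a standard PyTorch forward hook (the `capture` helper and shapes are illustrative, not from byol-pytorch):

```python
import torch
from vit_pytorch import ViT

model = ViT(image_size = 256, patch_size = 32, num_classes = 1000,
            dim = 1024, depth = 6, heads = 8, mlp_dim = 2048)

hidden = {}

def capture(module, inputs, output):
    # forward hook: stash the cls-token representation as it
    # flows through the nn.Identity layer
    hidden['cls'] = output

# look up the named submodule, as hidden_layer = 'to_cls_token' would
layer = dict(model.named_modules())['to_cls_token']
layer.register_forward_hook(capture)

model(torch.randn(2, 3, 256, 256))
print(hidden['cls'].shape)  # torch.Size([2, 1024]): the representation BYOL trains on
```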
```diff
@@ -88,8 +90,11 @@ class ViT(nn.Module):
 
         x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
         x = self.patch_to_embedding(x)
-        x = torch.cat((self.cls_token, x), dim=1)
+
+        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
         x += self.pos_embedding
         x = self.transformer(x)
 
-        return self.mlp_head(x[:, 0])
+        x = self.to_cls_token(x[:, 0])
+        return self.mlp_head(x)
```
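
The bug being fixed: `self.cls_token` has shape `(1, 1, dim)`, and `torch.cat` does not broadcast, so concatenating it directly onto a batch of patch embeddings fails for any batch size greater than 1. Expanding it across the batch dimension first (a view, no copy) makes the concatenation valid. A standalone sketch of the difference, with illustrative shapes:

```python
import torch

b, n, dim = 4, 64, 1024
cls_token = torch.randn(1, 1, dim)   # learned token, shape (1, 1, dim)
x = torch.randn(b, n, dim)           # batch of patch embeddings

# before the fix: torch.cat requires matching sizes in all non-cat
# dimensions, so this raises a RuntimeError whenever b > 1
try:
    torch.cat((cls_token, x), dim=1)
except RuntimeError as e:
    print('fails:', e)

# after the fix: expand replicates the token across the batch, then cat works
cls_tokens = cls_token.expand(b, -1, -1)   # (b, 1, dim)
out = torch.cat((cls_tokens, x), dim=1)    # (b, n + 1, dim)
print(out.shape)
```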