From 8fb261ca66b187958dc59f36d188ea78e59e89b3 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Oct 2020 14:55:29 -0700 Subject: [PATCH] fix a bug and add suggestion for BYOL pre-training --- README.md | 50 ++++++++++++++++++++++++++++++++++++++ setup.py | 2 +- vit_pytorch/vit_pytorch.py | 9 +++++-- 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 9d271e6..b7e00b6 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,56 @@ img = torch.randn(1, 3, 256, 256) preds = v(img) # (1, 1000) ``` +## Suggestion + +You can train this with a near-SOTA self-supervised learning technique, BYOL, using the following code. + +(1) +```bash +$ pip install byol-pytorch +``` + +(2) +```python +import torch +from vit_pytorch import ViT +from byol_pytorch import BYOL + +model = ViT( + image_size = 256, + patch_size = 32, + num_classes = 1000, + dim = 1024, + depth = 6, + heads = 8, + mlp_dim = 2048 +) + +learner = BYOL( + model, + image_size = 256, + hidden_layer = 'to_cls_token' +) + +opt = torch.optim.Adam(learner.parameters(), lr=3e-4) + +def sample_unlabelled_images(): + return torch.randn(20, 3, 256, 256) + +for _ in range(100): + images = sample_unlabelled_images() + loss = learner(images) + opt.zero_grad() + loss.backward() + opt.step() + learner.update_moving_average() # update moving average of target encoder + +# save your improved network +torch.save(model.state_dict(), './pretrained-net.pt') +``` + +A pytorch-lightning script is ready for you to use at the repository link above. 
+ ## Citations ```bibtex diff --git a/setup.py b/setup.py index b8d4755..22ed1df 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(), - version = '0.0.2', + version = '0.0.3', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', diff --git a/vit_pytorch/vit_pytorch.py b/vit_pytorch/vit_pytorch.py index 47fd990..54b12e9 100644 --- a/vit_pytorch/vit_pytorch.py +++ b/vit_pytorch/vit_pytorch.py @@ -77,6 +77,8 @@ class ViT(nn.Module): self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) self.transformer = Transformer(dim, depth, heads, mlp_dim) + self.to_cls_token = nn.Identity() + self.mlp_head = nn.Sequential( nn.Linear(dim, mlp_dim), nn.GELU(), @@ -88,8 +90,11 @@ class ViT(nn.Module): x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) x = self.patch_to_embedding(x) - x = torch.cat((self.cls_token, x), dim=1) + + cls_tokens = self.cls_token.expand(img.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) x += self.pos_embedding x = self.transformer(x) - return self.mlp_head(x[:, 0]) + x = self.to_cls_token(x[:, 0]) + return self.mlp_head(x)