From 173e07e02ead297e57b1b9fa3f01135c52b4a045 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 8 Mar 2021 07:28:31 -0800 Subject: [PATCH] cleanup and release 0.8.0 --- README.md | 17 ++++++++++------- setup.py | 2 +- vit_pytorch/__init__.py | 1 - vit_pytorch/{mpp_pytorch.py => mpp.py} | 11 ++++++----- 4 files changed, 17 insertions(+), 14 deletions(-) rename vit_pytorch/{mpp_pytorch.py => mpp.py} (93%) diff --git a/README.md b/README.md index 433227f..e9e7c7b 100644 --- a/README.md +++ b/README.md @@ -141,15 +141,14 @@ img = torch.randn(1, 3, 224, 224) v(img) # (1, 1000) ``` -## Research Ideas +## Masked Patch Prediction -### Self Supervised Training - -You can train using the original masked patch prediction task presented in the paper, with the following code. +Thanks to Zach, you can train using the original masked patch prediction task presented in the paper, with the following code. ```python import torch -from vit_pytorch import ViT, MPP +from vit_pytorch import ViT +from vit_pytorch.mpp import MPP model = ViT(image_size=256, patch_size=32, @@ -165,9 +164,9 @@ mpp_trainer = MPP( transformer=model, patch_size=32, dim=1024, - mask_prob=0.15, # probability of using token in masked prediction task + mask_prob=0.15, # probability of using token in masked prediction task random_patch_prob=0.30, # probability of randomly replacing a token being used for mpp - replace_prob=0.50, # probability of replacing a token being used for mpp with the mask token + replace_prob=0.50, # probability of replacing a token being used for mpp with the mask token ) opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4) @@ -188,6 +187,10 @@ for _ in range(100): torch.save(model.state_dict(), './pretrained-net.pt') ``` +## Research Ideas + +### Self Supervised Training + You can train this with a near SOTA self-supervised learning technique, BYOL, with the following code. (1) diff --git a/setup.py b/setup.py index 98819fb..cc706ce 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.7.6', + version = '0.8.0', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', diff --git a/vit_pytorch/__init__.py b/vit_pytorch/__init__.py index 1fcf01f..2a84eb0 100644 --- a/vit_pytorch/__init__.py +++ b/vit_pytorch/__init__.py @@ -1,2 +1 @@ from vit_pytorch.vit_pytorch import ViT -from vit_pytorch.mpp_pytorch import MPP diff --git a/vit_pytorch/mpp_pytorch.py b/vit_pytorch/mpp.py similarity index 93% rename from vit_pytorch/mpp_pytorch.py rename to vit_pytorch/mpp.py index ca520bc..ef59a79 100644 --- a/vit_pytorch/mpp_pytorch.py +++ b/vit_pytorch/mpp.py @@ -106,6 +106,7 @@ class MPP(nn.Module): self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels)) def forward(self, input, **kwargs): + transformer = self.transformer # clone original image for loss img = input.clone().detach() @@ -144,19 +145,19 @@ class MPP(nn.Module): masked_input[bool_mask_replace] = self.mask_token # linear embedding of patches - masked_input = self.transformer.patch_to_embedding(masked_input) + masked_input = transformer.to_patch_embedding[-1](masked_input) # add cls token to input sequence b, n, _ = masked_input.shape - cls_tokens = repeat(self.transformer.cls_token, '() n d -> b n d', b=b) + cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b) masked_input = torch.cat((cls_tokens, masked_input), dim=1) # add positional embeddings to input - masked_input += self.transformer.pos_embedding[:, :(n + 1)] - masked_input = self.transformer.dropout(masked_input) + masked_input += transformer.pos_embedding[:, :(n + 1)] + masked_input = transformer.dropout(masked_input) # get generator output and get mpp loss - masked_input = self.transformer.transformer(masked_input, **kwargs) + masked_input = transformer.transformer(masked_input, **kwargs) cls_logits = self.to_bits(masked_input) logits = cls_logits[:, 1:, :]