diff --git a/README.md b/README.md
index e14aad2..9a15b6e 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,7 @@
 - [Parallel ViT](#parallel-vit)
 - [Learnable Memory ViT](#learnable-memory-vit)
 - [Dino](#dino)
+- [EsViT](#esvit)
 - [Accessing Attention](#accessing-attention)
 - [Research Ideas](#research-ideas)
   * [Efficient Attention](#efficient-attention)
@@ -1076,6 +1077,80 @@ for _ in range(100):
 torch.save(model.state_dict(), './pretrained-net.pt')
 ```
 
+## EsViT
+
+`EsViT` is a variant of Dino (from above), re-engineered to support efficient `ViT`s with patch merging / downsampling by taking into account an extra regional loss between the augmented views. To quote the abstract, it `outperforms its supervised counterpart on 17 out of 18 datasets` at 3 times higher throughput.
+
+Even though it is named as though it were a new `ViT` variant, it is actually just a strategy for training any multi-stage `ViT` (in the paper, they focused on Swin). The example below will show how to use it with `CvT`. You'll need to set `hidden_layer` to the name of the layer within your efficient ViT that outputs the non-average-pooled visual representations, just before the global pooling and projection to logits.
+
+```python
+import torch
+from vit_pytorch.cvt import CvT
+from vit_pytorch.es_vit import EsViTTrainer
+
+cvt = CvT(
+    num_classes = 1000,
+    s1_emb_dim = 64,
+    s1_emb_kernel = 7,
+    s1_emb_stride = 4,
+    s1_proj_kernel = 3,
+    s1_kv_proj_stride = 2,
+    s1_heads = 1,
+    s1_depth = 1,
+    s1_mlp_mult = 4,
+    s2_emb_dim = 192,
+    s2_emb_kernel = 3,
+    s2_emb_stride = 2,
+    s2_proj_kernel = 3,
+    s2_kv_proj_stride = 2,
+    s2_heads = 3,
+    s2_depth = 2,
+    s2_mlp_mult = 4,
+    s3_emb_dim = 384,
+    s3_emb_kernel = 3,
+    s3_emb_stride = 2,
+    s3_proj_kernel = 3,
+    s3_kv_proj_stride = 2,
+    s3_heads = 4,
+    s3_depth = 10,
+    s3_mlp_mult = 4,
+    dropout = 0.
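+    # note: the stage strides above (4, 2, 2) progressively downsample a
+    # 256x256 input to 16x16 feature maps by the final stage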
+)
+
+learner = EsViTTrainer(
+    cvt,
+    image_size = 256,
+    hidden_layer = 'layers',             # hidden layer name or index, from which to extract the embedding
+    projection_hidden_size = 256,        # projector network hidden dimension
+    projection_layers = 4,               # number of layers in projection network
+    num_classes_K = 65336,               # output logits dimensions (referenced as K in paper)
+    student_temp = 0.9,                  # student temperature
+    teacher_temp = 0.04,                 # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
+    local_upper_crop_scale = 0.4,        # upper bound for local crop - 0.4 was recommended in the paper
+    global_lower_crop_scale = 0.5,       # lower bound for global crop - 0.5 was recommended in the paper
+    moving_average_decay = 0.9,          # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
+    center_moving_average_decay = 0.9,   # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
+)
+
+opt = torch.optim.AdamW(learner.parameters(), lr = 3e-4)
+
+def sample_unlabelled_images():
+    return torch.randn(8, 3, 256, 256)
+
+for _ in range(1000):
+    images = sample_unlabelled_images()
+    loss = learner(images)
+    opt.zero_grad()
+    loss.backward()
+    opt.step()
+    learner.update_moving_average()   # update moving average of teacher encoder and teacher centers
+
+# save your improved network
+torch.save(cvt.state_dict(), './pretrained-net.pt')
+```
+
 ## Accessing Attention
 
 If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below
diff --git a/setup.py b/setup.py
index e4474b4..ab458d0 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.33.2',
+  version = '0.34.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/cvt.py b/vit_pytorch/cvt.py
index add9ce9..2750284 100644
--- a/vit_pytorch/cvt.py
+++ b/vit_pytorch/cvt.py
@@ -164,12 +164,14 @@ class CvT(nn.Module):
 
         dim = config['emb_dim']
 
-        self.layers = nn.Sequential(
-            *layers,
+        self.layers = nn.Sequential(*layers)
+
+        self.to_logits = nn.Sequential(
             nn.AdaptiveAvgPool2d(1),
             Rearrange('... () () -> ...'),
             nn.Linear(dim, num_classes)
         )
 
     def forward(self, x):
-        return self.layers(x)
+        latents = self.layers(x)
+        return self.to_logits(latents)
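
The `cvt.py` split above is what lets `hidden_layer = 'layers'` in the README example resolve to the representations before pooling: `self.layers` now ends at the final stage, while the global pool and classifier head live in the new `self.to_logits`. A minimal sketch of what that hook sees, continuing from the `cvt` instance constructed in the README example (the shapes assume its 256×256 inputs and stage strides of 4, 2, 2):

```python
# continuing from the README example above

images = torch.randn(1, 3, 256, 256)

# `layers` stops at the final stage, returning the non-average-pooled
# feature maps - the hook point that `hidden_layer = 'layers'` names
latents = cvt.layers(images)      # (1, 384, 16, 16)

# `to_logits` holds everything after the hook:
# adaptive average pool -> flatten -> linear classifier
logits = cvt.to_logits(latents)   # (1, 1000)
```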