diff --git a/README.md b/README.md index d0bc379..367184d 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ - [RegionViT](#regionvit) - [NesT](#nest) - [Masked Autoencoder](#masked-autoencoder) +- [Simple Masked Image Modeling](#simple-masked-image-modeling) - [Masked Patch Prediction](#masked-patch-prediction) - [Dino](#dino) - [Accessing Attention](#accessing-attention) @@ -519,6 +520,46 @@ img = torch.randn(1, 3, 224, 224) pred = nest(img) # (1, 1000) ``` +## Simple Masked Image Modeling + + + +This paper proposes a simple masked image modeling (SimMIM) scheme, using only a linear projection off the masked tokens into pixel space followed by an L1 loss with the pixel values of the masked patches. Results are competitive with other more complicated approaches. + +You can use this as follows + +```python +import torch +from vit_pytorch import ViT +from vit_pytorch.simmim import SimMIM + +v = ViT( + image_size = 256, + patch_size = 32, + num_classes = 1000, + dim = 1024, + depth = 6, + heads = 8, + mlp_dim = 2048 +) + +mim = SimMIM( + encoder = v, + masking_ratio = 0.5 # they found 50% to yield the best results +) + +images = torch.randn(8, 3, 256, 256) + +loss = mim(images) +loss.backward() + +# that's all! +# do the above in a for loop many times with a lot of images and your vision transformer will learn + +torch.save(v.state_dict(), './trained-vit.pt') +``` + + ## Masked Autoencoder @@ -1026,6 +1067,17 @@ Coming from computer vision and new to transformers? 
class SimMIM(nn.Module):
    """Simple masked image modeling (SimMIM) training wrapper.

    A random subset of patch tokens is swapped for a learned mask token, the
    full token sequence is run through the encoder's transformer, and a single
    linear layer maps the encoded masked tokens back to pixel values. The loss
    is the mean L1 distance between those predictions and the original pixel
    values of the masked patches.

    Args:
        encoder: Module to be trained. The code reads ``pos_embedding``
            (its last dimension is used as the token width, and index 0
            along the position axis is skipped -- presumably the class-token
            slot, confirm against the encoder), ``to_patch_embedding`` (its
            first two entries are used as "image -> patches" and
            "patch -> embedding"; the second must expose ``weight``) and
            ``transformer``.
        masking_ratio: Fraction of patches to mask, strictly between 0 and 1.
    """

    def __init__(
        self,
        *,
        encoder,
        masking_ratio = 0.5
    ):
        super().__init__()
        assert 0 < masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # extract some hyperparameters and functions from encoder (vision transformer to be trained)

        self.encoder = encoder
        encoder_dim = encoder.pos_embedding.shape[-1]
        self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]

        # the input width of the patch embedding is the number of pixel values in one patch
        pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]

        # learned mask token and simple linear head back to pixel space

        self.mask_token = nn.Parameter(torch.randn(encoder_dim))
        self.to_pixels = nn.Linear(encoder_dim, pixel_values_per_patch)

    def forward(self, img):
        """Return the reconstruction loss for one batch of images.

        Args:
            img: Input accepted by the encoder's patching module.

        Returns:
            Scalar tensor: mean absolute error between predicted and true
            pixel values, averaged over every value of every masked patch.
        """
        device = img.device

        # get patches

        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        # for indexing purposes

        batch_range = torch.arange(batch, device = device)[:, None]

        # get positions (index 0 of the positional embedding is skipped)

        pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)]

        # patch to encoder tokens and add positions

        tokens = self.patch_to_emb(patches) + pos_emb

        # prepare mask tokens: one learned vector broadcast to every position, plus positions

        mask_tokens = self.mask_token.expand(batch, num_patches, -1) + pos_emb

        # calculate number of patches to be masked and pick that many random indices per sample

        num_masked = int(self.masking_ratio * num_patches)
        rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
        masked_indices = rand_indices[:, :num_masked]
        masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool()

        # substitute the mask token at the chosen positions

        tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)

        # attend with vision transformer

        encoded = self.encoder.transformer(tokens)

        # get the masked tokens

        encoded_mask_tokens = encoded[batch_range, masked_indices]

        # small linear projection for predicted pixel values

        pred_pixel_values = self.to_pixels(encoded_mask_tokens)

        # get the masked patches for the final reconstruction loss

        masked_patches = patches[batch_range, masked_indices]

        # l1_loss already averages over all masked values; dividing again by
        # num_masked would make the loss scale depend on the patch count

        return F.l1_loss(pred_pixel_values, masked_patches)