diff --git a/README.md b/README.md
index d0bc379..367184d 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,7 @@
- [RegionViT](#regionvit)
- [NesT](#nest)
- [Masked Autoencoder](#masked-autoencoder)
+- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
- [Dino](#dino)
- [Accessing Attention](#accessing-attention)
@@ -519,6 +520,46 @@ img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```
+## Simple Masked Image Modeling
+
+<img src="./images/simmim.png" width="400px"/>
+
+[This paper](https://arxiv.org/abs/2111.09886) proposes a simple masked image modeling (SimMIM) scheme, using only a linear projection of the masked tokens into pixel space, followed by an L1 loss against the pixel values of the masked patches. Results are competitive with other, more complicated approaches.
+
+You can use this as follows
+
+```python
+import torch
+from vit_pytorch import ViT
+from vit_pytorch.simmim import SimMIM
+
+v = ViT(
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    dim = 1024,
+    depth = 6,
+    heads = 8,
+    mlp_dim = 2048
+)
+
+mim = SimMIM(
+    encoder = v,
+    masking_ratio = 0.5  # they found 50% to yield the best results
+)
+
+images = torch.randn(8, 3, 256, 256)
+
+loss = mim(images)
+loss.backward()
+
+# that's all!
+# do the above in a for loop many times with a lot of images and your vision transformer will learn
+
+torch.save(v.state_dict(), './trained-vit.pt')
+```
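+
+After pretraining, the saved weights can be loaded back into a `ViT` of the same configuration for supervised fine-tuning or evaluation. A minimal sketch, assuming the hyperparameters used above:
+
+```python
+import torch
+from vit_pytorch import ViT
+
+# rebuild the same architecture that was pretrained with SimMIM
+v = ViT(
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    dim = 1024,
+    depth = 6,
+    heads = 8,
+    mlp_dim = 2048
+)
+
+# load the weights saved after pretraining
+v.load_state_dict(torch.load('./trained-vit.pt'))
+
+# fine-tune or predict on labeled images as usual
+img = torch.randn(1, 3, 256, 256)
+preds = v(img) # (1, 1000)
+```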
+
+
## Masked Autoencoder
@@ -1026,6 +1067,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
+```bibtex
+@misc{xie2021simmim,
+ title = {SimMIM: A Simple Framework for Masked Image Modeling},
+ author = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
+ year = {2021},
+ eprint = {2111.09886},
+ archivePrefix = {arXiv},
+ primaryClass = {cs.CV}
+}
+```
+
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
diff --git a/images/simmim.png b/images/simmim.png
new file mode 100644
index 0000000..7dfa85b
Binary files /dev/null and b/images/simmim.png differ
diff --git a/setup.py b/setup.py
index 956eee9..79f452e 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
- version = '0.22.0',
+ version = '0.23.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
diff --git a/vit_pytorch/simmim.py b/vit_pytorch/simmim.py
new file mode 100644
index 0000000..1537e76
--- /dev/null
+++ b/vit_pytorch/simmim.py
@@ -0,0 +1,87 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from einops import repeat
+
+from vit_pytorch.vit import Transformer
+
+class SimMIM(nn.Module):
+    def __init__(
+        self,
+        *,
+        encoder,
+        masking_ratio = 0.5
+    ):
+        super().__init__()
+        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
+        self.masking_ratio = masking_ratio
+
+        # extract some hyperparameters and functions from encoder (vision transformer to be trained)
+
+        self.encoder = encoder
+        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
+        self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
+        pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
+
+        # simple linear head
+
+        self.mask_token = nn.Parameter(torch.randn(encoder_dim))
+        self.to_pixels = nn.Linear(encoder_dim, pixel_values_per_patch)
+
+    def forward(self, img):
+        device = img.device
+
+        # get patches
+
+        patches = self.to_patch(img)
+        batch, num_patches, *_ = patches.shape
+
+        # for indexing purposes
+
+        batch_range = torch.arange(batch, device = device)[:, None]
+
+        # get positions
+
+        pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)]
+
+        # patch to encoder tokens and add positions
+
+        tokens = self.patch_to_emb(patches)
+        tokens = tokens + pos_emb
+
+        # prepare mask tokens
+
+        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_patches)
+        mask_tokens = mask_tokens + pos_emb
+
+        # calculate the number of patches to be masked, get random indices, and divide them into masked vs unmasked
+
+        num_masked = int(self.masking_ratio * num_patches)
+        rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
+        masked_indices = rand_indices[:, :num_masked]
+        masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool()
+
+        # mask tokens
+
+        tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)
+
+        # attend with vision transformer
+
+        encoded = self.encoder.transformer(tokens)
+
+        # get the masked tokens
+
+        encoded_mask_tokens = encoded[batch_range, masked_indices]
+
+        # small linear projection for predicted pixel values
+
+        pred_pixel_values = self.to_pixels(encoded_mask_tokens)
+
+        # get the masked patches for the final reconstruction loss
+
+        masked_patches = patches[batch_range, masked_indices]
+
+        # calculate reconstruction loss
+
+        recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked
+        return recon_loss
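+
+# minimal sanity-check sketch (illustrative only, assumes `vit_pytorch.vit.ViT` is importable):
+# runs a single SimMIM pretraining step on random images
+
+if __name__ == '__main__':
+    from vit_pytorch.vit import ViT
+
+    v = ViT(
+        image_size = 256,
+        patch_size = 32,
+        num_classes = 1000,
+        dim = 1024,
+        depth = 6,
+        heads = 8,
+        mlp_dim = 2048
+    )
+
+    mim = SimMIM(encoder = v, masking_ratio = 0.5)
+
+    images = torch.randn(2, 3, 256, 256)
+
+    loss = mim(images)
+    loss.backward()
+
+    print(f'reconstruction loss: {loss.item():.4f}')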