diff --git a/README.md b/README.md
index 8206a0a..69e7553 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@
- [Masked Autoencoder](#masked-autoencoder)
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
+- [Masked Position Prediction](#masked-position-prediction)
- [Adaptive Token Sampling](#adaptive-token-sampling)
- [Patch Merger](#patch-merger)
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
@@ -844,6 +845,53 @@ for _ in range(100):
torch.save(model.state_dict(), './pretrained-net.pt')
```
+## Masked Position Prediction
+
+<img src="./images/mp3.png" width="400px"></img>
+
+This paper introduces a masked position prediction pre-training criterion: positional embeddings are withheld and the network must predict the position of each patch from its content alone. The authors report that this strategy is more efficient than Masked Autoencoder pre-training while achieving comparable performance.
+
+```python
+import torch
+from vit_pytorch.mp3 import MP3
+
+model = MP3(
+ image_size=256,
+ patch_size=8,
+    masking_ratio=0.75,   # fraction of patches withheld from the attention context
+ dim=1024,
+ depth=6,
+ heads=8,
+ mlp_dim=2048,
+ dropout=0.1,
+)
+
+images = torch.randn(8, 3, 256, 256)
+
+loss = model(images)
+loss.backward()
+
+# that's all!
+# do the above in a for loop many times with a lot of images and your vision transformer will learn
+
+# save your improved vision transformer
+torch.save(model.state_dict(), './trained-vit.pt')
+```
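+
+Below is a minimal pre-training loop sketch. The `Adam` optimizer, learning rate, and random image tensors are illustrative placeholders for your own optimizer and data, not prescriptions from the paper.
+
+```python
+opt = torch.optim.Adam(model.parameters(), lr = 3e-4)
+
+for _ in range(100):
+    images = torch.randn(8, 3, 256, 256)  # substitute batches of real images here
+    loss = model(images)
+
+    opt.zero_grad()
+    loss.backward()
+    opt.step()
+```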
+
## Adaptive Token Sampling
diff --git a/images/mp3.png b/images/mp3.png
new file mode 100644
index 0000000..5ae8d9e
Binary files /dev/null and b/images/mp3.png differ
diff --git a/vit_pytorch/mp3.py b/vit_pytorch/mp3.py
new file mode 100644
index 0000000..d3daa7c
--- /dev/null
+++ b/vit_pytorch/mp3.py
@@ -0,0 +1,140 @@
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+
+# helpers
+
+def pair(t):
+ return t if isinstance(t, tuple) else (t, t)
+
+# pre-layernorm
+
+class PreNorm(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.fn = fn
+ def forward(self, x, **kwargs):
+ return self.fn(self.norm(x), **kwargs)
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim, dropout = 0.):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(dim, hidden_dim),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(hidden_dim, dim),
+ nn.Dropout(dropout)
+ )
+ def forward(self, x):
+ return self.net(x)
+
+# cross attention
+
+class CrossAttention(nn.Module):
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+
+ self.attend = nn.Softmax(dim = -1)
+ self.dropout = nn.Dropout(dropout)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context):
+ b, n, _, h = *x.shape, self.heads
+
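+        # queries are formed from x, while keys and values come from the context sequence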
+ qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+
+ attn = self.attend(dots)
+ attn = self.dropout(attn)
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+class Transformer(nn.Module):
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(nn.ModuleList([
+ PreNorm(dim, CrossAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+ ]))
+ def forward(self, x, context):
+ for attn, ff in self.layers:
+ x = attn(x, context=context) + x
+ x = ff(x) + x
+ return x
+
+# Masked Position Prediction Pre-Training
+
+class MP3(nn.Module):
+ def __init__(self, *, image_size, patch_size, masking_ratio, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0.):
+ super().__init__()
+ image_height, image_width = pair(image_size)
+ patch_height, patch_width = pair(patch_size)
+
+ assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+
+ assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
+ self.masking_ratio = masking_ratio
+
+ num_patches = (image_height // patch_height) * (image_width // patch_width)
+ patch_dim = channels * patch_height * patch_width
+
+ self.to_patch_embedding = nn.Sequential(
+ Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
+ nn.LayerNorm(patch_dim),
+ nn.Linear(patch_dim, dim),
+ nn.LayerNorm(dim),
+ )
+
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+
+ self.mlp_head = nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, num_patches)
+ )
+
+ def forward(self, img):
+ device = img.device
+ tokens = self.to_patch_embedding(img)
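+        # note that no positional embedding is added - the network must infer each patch's position from its content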
+ batch, num_patches, *_ = tokens.shape
+
+ # Masking
+ num_masked = int(self.masking_ratio * num_patches)
+ rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
+ masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]
+
+ batch_range = torch.arange(batch, device = device)[:, None]
+ tokens_unmasked = tokens[batch_range, unmasked_indices]
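+        # all tokens act as queries, but only the unmasked tokens are offered as cross-attention context (keys / values)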
+
+        # predict position (patch index) logits for every token
+        logits = rearrange(self.mlp_head(self.transformer(tokens, tokens_unmasked)), 'b n d -> (b n) d')
+
+        # each token's label is its own patch index; F.cross_entropy applies log-softmax internally,
+        # so the raw logits are passed in directly
+        labels = repeat(torch.arange(num_patches, device = device), 'n -> b n', b = batch).flatten()
+        loss = F.cross_entropy(logits, labels)
+
+ return loss