diff --git a/README.md b/README.md
index a5e19ee..ca500d2 100644
--- a/README.md
+++ b/README.md
@@ -378,6 +378,32 @@ img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```
+## NesT
+
+
+<img src="./images/nest.png" width="400px"></img>
+
+This paper processes the image in hierarchical stages, with attention restricted to the tokens of local blocks, which are aggregated as the network moves up the hierarchy. The aggregation is done in the image plane and contains a convolution (followed by a max pool) to allow information to pass across block boundaries.
+
+You can use it with the following code (e.g. NesT-T)
+
+```python
+import torch
+from vit_pytorch.nest import NesT
+
+nest = NesT(
+ image_size = 224,
+ patch_size = 4,
+ dim = 96,
+ heads = 3,
+    num_hierarchies = 3, # number of hierarchies
+    block_repeats = (8, 4, 1), # the number of transformer blocks at each hierarchy, starting from the bottom
+ num_classes = 1000
+)
+
+img = torch.randn(1, 3, 224, 224)
+pred = nest(img) # (1, 1000)
+```
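+
+Under the hood, NesT partitions each feature map into non-overlapping blocks and runs attention within every block independently, by folding the blocks into the batch dimension. Below is a minimal sketch of that partitioning trick (not part of the library; the tensor shapes are just illustrative values for the bottom hierarchy of the model above).
+
+```python
+import torch
+from einops import rearrange
+
+feat = torch.randn(1, 96, 56, 56)    # (batch, dim, height, width) feature map after patch embedding
+block_size = 4                       # a 4 x 4 grid of local blocks
+
+# fold every local block into the batch dimension: (1, 96, 56, 56) -> (16, 96, 14, 14)
+blocks = rearrange(feat, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
+
+# ... transformer layers act on `blocks`, so tokens never attend across block boundaries ...
+
+# undo the partition before aggregating (conv + max pool) down to the next hierarchy
+feat = rearrange(blocks, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
+```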
+
## Masked Patch Prediction
Thanks to Zach, you can train using the original masked patch prediction task presented in the paper, with the following code.
@@ -787,6 +813,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
+```bibtex
+@misc{zhang2021aggregating,
+ title = {Aggregating Nested Transformers},
+ author = {Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
+ year = {2021},
+ eprint = {2105.12723},
+ archivePrefix = {arXiv},
+ primaryClass = {cs.CV}
+}
+```
+
```bibtex
@misc{caron2021emerging,
title = {Emerging Properties in Self-Supervised Vision Transformers},
diff --git a/images/nest.png b/images/nest.png
new file mode 100644
index 0000000..3ddbdc2
Binary files /dev/null and b/images/nest.png differ
diff --git a/setup.py b/setup.py
index 5c86e60..79d7059 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
- version = '0.18.4',
+ version = '0.19.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
diff --git a/vit_pytorch/nest.py b/vit_pytorch/nest.py
new file mode 100644
index 0000000..a54fa34
--- /dev/null
+++ b/vit_pytorch/nest.py
@@ -0,0 +1,177 @@
+import torch
+from torch import nn, einsum
+
+from einops import rearrange
+from einops.layers.torch import Rearrange, Reduce
+
+# helpers
+
+def cast_tuple(val, depth):
+ return val if isinstance(val, tuple) else ((val,) * depth)
+
+# classes
+
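+# layer norm over the channel dimension of a (batch, channels, height, width) feature map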
+class ChanNorm(nn.Module):
+ def __init__(self, dim, eps = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+ self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
+
+ def forward(self, x):
+ std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
+ mean = torch.mean(x, dim = 1, keepdim = True)
+ return (x - mean) / (std + self.eps) * self.g + self.b
+
+class PreNorm(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.norm = ChanNorm(dim)
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ return self.fn(self.norm(x), **kwargs)
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, mlp_mult = 4, dropout = 0.):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Conv2d(dim, dim * mlp_mult, 1),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Conv2d(dim * mlp_mult, dim, 1),
+ nn.Dropout(dropout)
+ )
+ def forward(self, x):
+ return self.net(x)
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads = 8, dropout = 0.):
+ super().__init__()
+ assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
+ dim_head = dim // heads
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+
+ self.attend = nn.Softmax(dim = -1)
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1, bias = False)
+
+ self.to_out = nn.Sequential(
+ nn.Conv2d(dim, dim, 1),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
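+        # the input arrives with local blocks already folded into the batch dimension, so attention only mixes tokens within one block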
+ b, c, h, w, heads = *x.shape, self.heads
+
+ qkv = self.to_qkv(x).chunk(3, dim = 1)
+ q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)
+
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+
+ attn = self.attend(dots)
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
+ return self.to_out(out)
+
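+# aggregation between hierarchies - a 3x3 convolution lets information cross block boundaries before max pooling halves the spatial resolution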
+def Aggregate(dim, dim_out):
+ return nn.Sequential(
+ nn.Conv2d(dim, dim_out, 3, padding = 1),
+ ChanNorm(dim_out),
+ nn.MaxPool2d(2)
+ )
+
+class Transformer(nn.Module):
+ def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.):
+ super().__init__()
+ self.layers = nn.ModuleList([])
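+        # learned positional embedding - one value per spatial position within a block, broadcast across channels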
+ self.pos_emb = nn.Parameter(torch.randn(seq_len))
+
+ for _ in range(depth):
+ self.layers.append(nn.ModuleList([
+ PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
+ PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
+ ]))
+ def forward(self, x):
+ *_, h, w = x.shape
+
+ pos_emb = self.pos_emb[:(h * w)]
+ pos_emb = rearrange(pos_emb, '(h w) -> () () h w', h = h, w = w)
+ x = x + pos_emb
+
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+ return x
+
+class NesT(nn.Module):
+ def __init__(
+ self,
+ *,
+ image_size,
+ patch_size,
+ num_classes,
+ dim,
+ heads,
+        num_hierarchies,
+ block_repeats,
+ mlp_mult = 4,
+ channels = 3,
+ dim_head = 64,
+ dropout = 0.
+ ):
+ super().__init__()
+ assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.'
+ num_patches = (image_size // patch_size) ** 2
+ patch_dim = channels * patch_size ** 2
+ fmap_size = image_size // patch_size
+        blocks = 2 ** (num_hierarchies - 1)
+
+        seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across the hierarchy
+        mults = [2 ** i for i in range(num_hierarchies)]
+
+ layer_heads = list(map(lambda t: t * heads, mults))
+ layer_dims = list(map(lambda t: t * dim, mults))
+
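+        # repeat the last dim so every hierarchy gets a (dim_in, dim_out) pair - the top level's aggregate is replaced by an identity below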
+ layer_dims = [*layer_dims, layer_dims[-1]]
+ dim_pairs = zip(layer_dims[:-1], layer_dims[1:])
+
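+        # patch embedding - a space-to-depth rearrange followed by a pointwise convolution to the first hierarchy's dimension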
+ self.to_patch_embedding = nn.Sequential(
+ Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
+ nn.Conv2d(patch_dim, layer_dims[0], 1),
+ )
+
+        block_repeats = cast_tuple(block_repeats, num_hierarchies)
+
+ self.layers = nn.ModuleList([])
+
+        for level, heads, (dim_in, dim_out), block_repeat in zip(reversed(range(num_hierarchies)), layer_heads, dim_pairs, block_repeats):
+ is_last = level == 0
+ depth = block_repeat
+
+ self.layers.append(nn.ModuleList([
+ Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout),
+ Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
+ ]))
+
+        self.mlp_head = nn.Sequential(
+            ChanNorm(layer_dims[-1]),
+            Reduce('b c h w -> b c', 'mean'),
+            nn.Linear(layer_dims[-1], num_classes)
+        )
+
+ def forward(self, img):
+ x = self.to_patch_embedding(img)
+ b, c, h, w = x.shape
+
+        num_hierarchies = len(self.layers)
+
+        for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
+ block_size = 2 ** level
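+            # partition the feature map into a (block_size x block_size) grid of local blocks, folded into the batch dimension for block-local attention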
+ x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
+ x = transformer(x)
+ x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
+ x = aggregate(x)
+
+ return self.mlp_head(x)