diff --git a/README.md b/README.md
index ea188a1..44736e7 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,7 @@
- [RegionViT](#regionvit)
- [ScalableViT](#scalablevit)
- [SepViT](#sepvit)
+- [MaxViT](#maxvit)
- [NesT](#nest)
- [MobileViT](#mobilevit)
- [Masked Autoencoder](#masked-autoencoder)
@@ -596,6 +597,37 @@ img = torch.randn(1, 3, 224, 224)
preds = v(img) # (1, 1000)
```
+## MaxViT
+
+
+
+This paper proposes a hybrid convolutional / attention network, using MBConv from the convolution side, and then block / grid axial sparse attention.
+
+They also claim this specific vision transformer is good for generative models (GANs).
+
+ex. MaxViT-S
+
+```python
+import torch
+from vit_pytorch.max_vit import MaxViT
+
+v = MaxViT(
+ num_classes = 1000,
+ dim_conv_stem = 64, # dimension of the convolutional stem, would default to dimension of first layer if not specified
+ dim = 96, # dimension of first layer, doubles every layer
+ dim_head = 32, # dimension of attention heads, kept at 32 in paper
+ depth = (2, 2, 5, 2), # number of MaxViT blocks per stage, which consists of MBConv, block-like attention, grid-like attention
+ window_size = 7, # window size for block and grids
+ mbconv_expansion_rate = 4, # expansion rate of MBConv
+ mbconv_shrinkage_rate = 0.25, # shrinkage rate of squeeze-excitation in MBConv
+ dropout = 0.1 # dropout
+)
+
+img = torch.randn(2, 3, 224, 224)
+
+preds = v(img) # (2, 1000)
+```
+
## NesT
@@ -1544,6 +1576,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
+```bibtex
+@inproceedings{Tu2022MaxViTMV,
+ title = {MaxViT: Multi-Axis Vision Transformer},
+ author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
+ year = {2022}
+}
+```
+
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
diff --git a/images/max-vit.png b/images/max-vit.png
new file mode 100644
index 0000000..2b76f01
Binary files /dev/null and b/images/max-vit.png differ
diff --git a/setup.py b/setup.py
index df44508..7a42e3b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
- version = '0.32.2',
+ version = '0.33.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
@@ -16,7 +16,7 @@ setup(
],
install_requires=[
'einops>=0.4.1',
- 'torch>=1.6',
+ 'torch>=1.10',
'torchvision'
],
setup_requires=[
diff --git a/vit_pytorch/crossformer.py b/vit_pytorch/crossformer.py
index 65ec3cb..2725532 100644
--- a/vit_pytorch/crossformer.py
+++ b/vit_pytorch/crossformer.py
@@ -108,7 +108,7 @@ class Attention(nn.Module):
# calculate and store indices for retrieving bias
pos = torch.arange(window_size)
- grid = torch.stack(torch.meshgrid(pos, pos))
+ grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = grid[:, None] - grid[None, :]
rel_pos += window_size - 1
@@ -144,7 +144,7 @@ class Attention(nn.Module):
# add dynamic positional bias
pos = torch.arange(-wsz, wsz + 1, device = device)
- rel_pos = torch.stack(torch.meshgrid(pos, pos))
+ rel_pos = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
biases = self.dpb(rel_pos.float())
rel_pos_bias = biases[self.rel_pos_indices]
diff --git a/vit_pytorch/levit.py b/vit_pytorch/levit.py
index ffb3efa..bf9a092 100644
--- a/vit_pytorch/levit.py
+++ b/vit_pytorch/levit.py
@@ -71,8 +71,8 @@ class Attention(nn.Module):
q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
k_range = torch.arange(fmap_size)
- q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
- k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)
+ q_pos = torch.stack(torch.meshgrid(q_range, q_range, indexing = 'ij'), dim = -1)
+ k_pos = torch.stack(torch.meshgrid(k_range, k_range, indexing = 'ij'), dim = -1)
q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
diff --git a/vit_pytorch/max_vit.py b/vit_pytorch/max_vit.py
new file mode 100644
index 0000000..cf9ac45
--- /dev/null
+++ b/vit_pytorch/max_vit.py
@@ -0,0 +1,270 @@
+from functools import partial
+
+import torch
+from torch import nn, einsum
+
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange, Reduce
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+def default(val, d):
+ return val if exists(val) else d
+
+def cast_tuple(val, length = 1):
+ return val if isinstance(val, tuple) else ((val,) * length)
+
+# helper classes
+
+class PreNormResidual(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.fn = fn
+
+ def forward(self, x):
+ return self.fn(self.norm(x)) + x
+
+# MBConv
+
+class SqueezeExcitation(nn.Module):
+ def __init__(self, dim, shrinkage_rate = 0.25):
+ super().__init__()
+ hidden_dim = int(dim * shrinkage_rate)
+
+ self.gate = nn.Sequential(
+ Reduce('b c h w -> b c', 'mean'),
+ nn.Linear(dim, hidden_dim, bias = False),
+ nn.SiLU(),
+ nn.Linear(hidden_dim, dim, bias = False),
+ nn.Sigmoid(),
+ Rearrange('b c -> b c 1 1')
+ )
+
+ def forward(self, x):
+ return x * self.gate(x)
+
+
+class MBConvResidual(nn.Module):
+ def __init__(self, fn, dropout = 0.):
+ super().__init__()
+ self.fn = fn
+ self.dropsample = Dropsample(dropout)
+
+ def forward(self, x):
+ out = self.fn(x)
+ out = self.dropsample(out)
+ return out
+
+class Dropsample(nn.Module):
+ def __init__(self, prob = 0):
+ super().__init__()
+ self.prob = prob
+
+ def forward(self, x):
+ device = x.device
+
+ if self.prob == 0. or (not self.training):
+ return x
+
+ keep_mask = torch.FloatTensor((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
+ return x * keep_mask / (1 - self.prob)
+
+def MBConv(
+ dim_in,
+ dim_out,
+ *,
+ downsample,
+ expansion_rate = 4,
+ shrinkage_rate = 0.25,
+ dropout = 0.
+):
+ hidden_dim = int(expansion_rate * dim_out)
+ stride = 2 if downsample else 1
+
+ net = nn.Sequential(
+ nn.Conv2d(dim_in, dim_out, 1),
+ nn.BatchNorm2d(dim_out),
+ nn.SiLU(),
+ nn.Conv2d(dim_out, dim_out, 3, stride = stride, padding = 1, groups = dim_out),
+ SqueezeExcitation(dim_out, shrinkage_rate = shrinkage_rate),
+ nn.Conv2d(dim_out, dim_out, 1),
+ nn.BatchNorm2d(dim_out)
+ )
+
+ if dim_in == dim_out and not downsample:
+ net = MBConvResidual(net, dropout = dropout)
+
+ return net
+
+# attention related classes
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_head = 32,
+ dropout = 0.,
+ window_size = 7
+ ):
+ super().__init__()
+ assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
+
+ self.heads = dim // dim_head
+ self.scale = dim_head ** -0.5
+
+ self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
+
+ self.attend = nn.Sequential(
+ nn.Softmax(dim = -1),
+ nn.Dropout(dropout)
+ )
+
+ self.to_out = nn.Sequential(
+ nn.Linear(dim, dim, bias = False),
+ nn.Dropout(dropout)
+ )
+
+ # relative positional bias
+
+ self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
+
+ pos = torch.arange(window_size)
+ grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
+ grid = rearrange(grid, 'c i j -> (i j) c')
+ rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
+ rel_pos += window_size - 1
+ rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
+
+ self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
+
+ def forward(self, x):
+ batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
+
+ # flatten
+
+ x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
+
+ # project for queries, keys, values
+
+ q, k, v = self.to_qkv(x).chunk(3, dim = -1)
+
+ # split heads
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d ) -> b h n d', h = h), (q, k, v))
+
+ # scale
+
+ q = q * self.scale
+
+ # sim
+
+ sim = einsum('b h i d, b h j d -> b h i j', q, k)
+
+ # add positional bias
+
+ bias = self.rel_pos_bias(self.rel_pos_indices)
+ sim = sim + rearrange(bias, 'i j h -> h i j')
+
+ # attention
+
+ attn = self.attend(sim)
+
+ # aggregate
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+
+ # merge heads
+
+ out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)
+
+ # combine heads out
+
+ out = self.to_out(out)
+ return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)
+
+class MaxViT(nn.Module):
+ def __init__(
+ self,
+ *,
+ num_classes,
+ dim,
+ depth,
+ dim_head = 32,
+ dim_conv_stem = None,
+ window_size = 7,
+ mbconv_expansion_rate = 4,
+ mbconv_shrinkage_rate = 0.25,
+ dropout = 0.1,
+ channels = 3
+ ):
+ super().__init__()
+ assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
+
+ # convolutional stem
+
+ dim_conv_stem = default(dim_conv_stem, dim)
+
+ self.conv_stem = nn.Sequential(
+ nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
+ nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
+ )
+
+ # variables
+
+ num_stages = len(depth)
+
+ dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
+ dims = (dim_conv_stem, *dims)
+ dim_pairs = tuple(zip(dims[:-1], dims[1:]))
+
+ self.layers = nn.ModuleList([])
+
+ # shorthand for window size for efficient block - grid like attention
+
+ w = window_size
+
+ # iterate through stages
+
+ for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
+ for stage_ind in range(layer_depth):
+ is_first = stage_ind == 0
+ stage_dim_in = layer_dim_in if is_first else layer_dim
+
+ block = nn.Sequential(
+ MBConv(
+ stage_dim_in,
+ layer_dim,
+ downsample = is_first,
+ expansion_rate = mbconv_expansion_rate,
+ shrinkage_rate = mbconv_shrinkage_rate
+ ),
+ Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
+ PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+ Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
+
+ Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
+ PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+ Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
+ )
+
+ self.layers.append(block)
+
+ # mlp head out
+
+ self.mlp_head = nn.Sequential(
+ Reduce('b d h w -> b d', 'mean'),
+ nn.LayerNorm(dims[-1]),
+ nn.Linear(dims[-1], num_classes)
+ )
+
+ def forward(self, x):
+ x = self.conv_stem(x)
+
+ for stage in self.layers:
+ x = stage(x)
+
+ return self.mlp_head(x)
diff --git a/vit_pytorch/regionvit.py b/vit_pytorch/regionvit.py
index bdb6095..2e155a1 100644
--- a/vit_pytorch/regionvit.py
+++ b/vit_pytorch/regionvit.py
@@ -138,7 +138,7 @@ class R2LTransformer(nn.Module):
h_range = torch.arange(window_size_h, device = device)
w_range = torch.arange(window_size_w, device = device)
- grid_x, grid_y = torch.meshgrid(h_range, w_range)
+ grid_x, grid_y = torch.meshgrid(h_range, w_range, indexing = 'ij')
grid = torch.stack((grid_x, grid_y))
grid = rearrange(grid, 'c h w -> c (h w)')
grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)