diff --git a/README.md b/README.md index ea188a1..44736e7 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ - [RegionViT](#regionvit) - [ScalableViT](#scalablevit) - [SepViT](#sepvit) +- [MaxViT](#maxvit) - [NesT](#nest) - [MobileViT](#mobilevit) - [Masked Autoencoder](#masked-autoencoder) @@ -596,6 +597,37 @@ img = torch.randn(1, 3, 224, 224) preds = v(img) # (1, 1000) ``` +## MaxViT + + + +This paper proposes a hybrid convolutional / attention network, using MBConv from the convolution side, and then block / grid axial sparse attention. + +They also claim this specific vision transformer is good for generative models (GANs). + +ex. MaxViT-S + +```python +import torch +from vit_pytorch.max_vit import MaxViT + +v = MaxViT( + num_classes = 1000, + dim_conv_stem = 64, # dimension of the convolutional stem, would default to dimension of first layer if not specified + dim = 96, # dimension of first layer, doubles every layer + dim_head = 32, # dimension of attention heads, kept at 32 in paper + depth = (2, 2, 5, 2), # number of MaxViT blocks per stage, which consists of MBConv, block-like attention, grid-like attention + window_size = 7, # window size for block and grids + mbconv_expansion_rate = 4, # expansion rate of MBConv + mbconv_shrinkage_rate = 0.25, # shrinkage rate of squeeze-excitation in MBConv + dropout = 0.1 # dropout +) + +img = torch.randn(2, 3, 224, 224) + +preds = v(img) # (2, 1000) +``` + ## NesT @@ -1544,6 +1576,14 @@ Coming from computer vision and new to transformers? Here are some resources tha } ``` +```bibtex +@inproceedings{Tu2022MaxViTMV, + title = {MaxViT: Multi-Axis Vision Transformer}, + author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li}, + year = {2022} +} +``` + ```bibtex @misc{vaswani2017attention, title = {Attention Is All You Need}, diff --git a/images/max-vit.png b/images/max-vit.png new file mode 100644 index 0000000..2b76f01 Binary files /dev/null and b/images/max-vit.png differ diff --git a/setup.py b/setup.py index df44508..7a42e3b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.32.2', + version = '0.33.0', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', @@ -16,7 +16,7 @@ setup( ], install_requires=[ 'einops>=0.4.1', - 'torch>=1.6', + 'torch>=1.10', 'torchvision' ], setup_requires=[ diff --git a/vit_pytorch/crossformer.py b/vit_pytorch/crossformer.py index 65ec3cb..2725532 100644 --- a/vit_pytorch/crossformer.py +++ b/vit_pytorch/crossformer.py @@ -108,7 +108,7 @@ class Attention(nn.Module): # calculate and store indices for retrieving bias pos = torch.arange(window_size) - grid = torch.stack(torch.meshgrid(pos, pos)) + grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij')) grid = rearrange(grid, 'c i j -> (i j) c') rel_pos = grid[:, None] - grid[None, :] rel_pos += window_size - 1 @@ -144,7 +144,7 @@ class Attention(nn.Module): # add dynamic positional bias pos = torch.arange(-wsz, wsz + 1, device = device) - rel_pos = torch.stack(torch.meshgrid(pos, pos)) + rel_pos = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij')) rel_pos = rearrange(rel_pos, 'c i j -> (i j) c') biases = self.dpb(rel_pos.float()) rel_pos_bias = biases[self.rel_pos_indices] diff --git a/vit_pytorch/levit.py b/vit_pytorch/levit.py index ffb3efa..bf9a092 100644 --- a/vit_pytorch/levit.py +++ b/vit_pytorch/levit.py @@ -71,8 +71,8 @@ class Attention(nn.Module): q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1)) k_range = torch.arange(fmap_size) - q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1) - k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1) + q_pos = torch.stack(torch.meshgrid(q_range, q_range, indexing = 'ij'), dim = -1) + k_pos = torch.stack(torch.meshgrid(k_range, k_range, indexing = 'ij'), dim = -1) q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos)) rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs() diff --git a/vit_pytorch/max_vit.py b/vit_pytorch/max_vit.py new file mode 100644 index 0000000..cf9ac45 --- /dev/null +++ b/vit_pytorch/max_vit.py @@ -0,0 +1,270 @@ +from functools import partial + +import torch +from torch import nn, einsum + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce + +# helpers + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def cast_tuple(val, length = 1): + return val if isinstance(val, tuple) else ((val,) * length) + +# helper classes + +class PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + x + +# MBConv + +class SqueezeExcitation(nn.Module): + def __init__(self, dim, shrinkage_rate = 0.25): + super().__init__() + hidden_dim = int(dim * shrinkage_rate) + + self.gate = nn.Sequential( + Reduce('b c h w -> b c', 'mean'), + nn.Linear(dim, hidden_dim, bias = False), + nn.SiLU(), + nn.Linear(hidden_dim, dim, bias = False), + nn.Sigmoid(), + Rearrange('b c -> b c 1 1') + ) + + def forward(self, x): + return x * self.gate(x) + + +class MBConvResidual(nn.Module): + def __init__(self, fn, dropout = 0.): + super().__init__() + self.fn = fn + self.dropsample = Dropsample(dropout) + + def forward(self, x): + out = self.fn(x) + out = self.dropsample(out) + return out + +class Dropsample(nn.Module): + def __init__(self, prob = 0): + super().__init__() + self.prob = prob + + def forward(self, x): + device = x.device + + if self.prob == 0. or (not self.training): + return x + + keep_mask = torch.FloatTensor((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob + return x * keep_mask / (1 - self.prob) + +def MBConv( + dim_in, + dim_out, + *, + downsample, + expansion_rate = 4, + shrinkage_rate = 0.25, + dropout = 0. +): + hidden_dim = int(expansion_rate * dim_out) + stride = 2 if downsample else 1 + + net = nn.Sequential( + nn.Conv2d(dim_in, dim_out, 1), + nn.BatchNorm2d(dim_out), + nn.SiLU(), + nn.Conv2d(dim_out, dim_out, 3, stride = stride, padding = 1, groups = dim_out), + SqueezeExcitation(dim_out, shrinkage_rate = shrinkage_rate), + nn.Conv2d(dim_out, dim_out, 1), + nn.BatchNorm2d(dim_out) + ) + + if dim_in == dim_out and not downsample: + net = MBConvResidual(net, dropout = dropout) + + return net + +# attention related classes + +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head = 32, + dropout = 0., + window_size = 7 + ): + super().__init__() + assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head' + + self.heads = dim // dim_head + self.scale = dim_head ** -0.5 + + self.to_qkv = nn.Linear(dim, dim * 3, bias = False) + + self.attend = nn.Sequential( + nn.Softmax(dim = -1), + nn.Dropout(dropout) + ) + + self.to_out = nn.Sequential( + nn.Linear(dim, dim, bias = False), + nn.Dropout(dropout) + ) + + # relative positional bias + + self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads) + + pos = torch.arange(window_size) + grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij')) + grid = rearrange(grid, 'c i j -> (i j) c') + rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...') + rel_pos += window_size - 1 + rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1) + + self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False) + + def forward(self, x): + batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads + + # flatten + + x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d') + + # project for queries, keys, values + + q, k, v = self.to_qkv(x).chunk(3, dim = -1) + + # split heads + + q, k, v = map(lambda t: rearrange(t, 'b n (h d ) -> b h n d', h = h), (q, k, v)) + + # scale + + q = q * self.scale + + # sim + + sim = einsum('b h i d, b h j d -> b h i j', q, k) + + # add positional bias + + bias = self.rel_pos_bias(self.rel_pos_indices) + sim = sim + rearrange(bias, 'i j h -> h i j') + + # attention + + attn = self.attend(sim) + + # aggregate + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + + # merge heads + + out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width) + + # combine heads out + + out = self.to_out(out) + return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width) + +class MaxViT(nn.Module): + def __init__( + self, + *, + num_classes, + dim, + depth, + dim_head = 32, + dim_conv_stem = None, + window_size = 7, + mbconv_expansion_rate = 4, + mbconv_shrinkage_rate = 0.25, + dropout = 0.1, + channels = 3 + ): + super().__init__() + assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage' + + # convolutional stem + + dim_conv_stem = default(dim_conv_stem, dim) + + self.conv_stem = nn.Sequential( + nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1), + nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1) + ) + + # variables + + num_stages = len(depth) + + dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages))) + dims = (dim_conv_stem, *dims) + dim_pairs = tuple(zip(dims[:-1], dims[1:])) + + self.layers = nn.ModuleList([]) + + # shorthand for window size for efficient block - grid like attention + + w = window_size + + # iterate through stages + + for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)): + for stage_ind in range(layer_depth): + is_first = stage_ind == 0 + stage_dim_in = layer_dim_in if is_first else layer_dim + + block = nn.Sequential( + MBConv( + stage_dim_in, + layer_dim, + downsample = is_first, + expansion_rate = mbconv_expansion_rate, + shrinkage_rate = mbconv_shrinkage_rate + ), + Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention + PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)), + Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'), + + Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention + PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)), + Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'), + ) + + self.layers.append(block) + + # mlp head out + + self.mlp_head = nn.Sequential( + Reduce('b d h w -> b d', 'mean'), + nn.LayerNorm(dims[-1]), + nn.Linear(dims[-1], num_classes) + ) + + def forward(self, x): + x = self.conv_stem(x) + + for stage in self.layers: + x = stage(x) + + return self.mlp_head(x) diff --git a/vit_pytorch/regionvit.py b/vit_pytorch/regionvit.py index bdb6095..2e155a1 100644 --- a/vit_pytorch/regionvit.py +++ b/vit_pytorch/regionvit.py @@ -138,7 +138,7 @@ class R2LTransformer(nn.Module): h_range = torch.arange(window_size_h, device = device) w_range = torch.arange(window_size_w, device = device) - grid_x, grid_y = torch.meshgrid(h_range, w_range) + grid_x, grid_y = torch.meshgrid(h_range, w_range, indexing = 'ij') grid = torch.stack((grid_x, grid_y)) grid = rearrange(grid, 'c h w -> c (h w)') grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)