mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
maxvit intent to build (#211)
complete hybrid mbconv + block / grid efficient self attention MaxViT
This commit is contained in:
40
README.md
40
README.md
@@ -20,6 +20,7 @@
|
|||||||
- [RegionViT](#regionvit)
|
- [RegionViT](#regionvit)
|
||||||
- [ScalableViT](#scalablevit)
|
- [ScalableViT](#scalablevit)
|
||||||
- [SepViT](#sepvit)
|
- [SepViT](#sepvit)
|
||||||
|
- [MaxViT](#maxvit)
|
||||||
- [NesT](#nest)
|
- [NesT](#nest)
|
||||||
- [MobileViT](#mobilevit)
|
- [MobileViT](#mobilevit)
|
||||||
- [Masked Autoencoder](#masked-autoencoder)
|
- [Masked Autoencoder](#masked-autoencoder)
|
||||||
@@ -596,6 +597,37 @@ img = torch.randn(1, 3, 224, 224)
|
|||||||
preds = v(img) # (1, 1000)
|
preds = v(img) # (1, 1000)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## MaxViT
|
||||||
|
|
||||||
|
<img src="./images/max-vit.png" width="400px"></img>
|
||||||
|
|
||||||
|
This paper proposes a hybrid convolutional / attention network, using MBConv from the convolution side, and then block / grid axial sparse attention.
|
||||||
|
|
||||||
|
They also claim this specific vision transformer is good for generative models (GANs).
|
||||||
|
|
||||||
|
ex. MaxViT-S
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from vit_pytorch.max_vit import MaxViT
|
||||||
|
|
||||||
|
v = MaxViT(
|
||||||
|
num_classes = 1000,
|
||||||
|
dim_conv_stem = 64, # dimension of the convolutional stem, would default to dimension of first layer if not specified
|
||||||
|
dim = 96, # dimension of first layer, doubles every layer
|
||||||
|
dim_head = 32, # dimension of attention heads, kept at 32 in paper
|
||||||
|
depth = (2, 2, 5, 2), # number of MaxViT blocks per stage, which consists of MBConv, block-like attention, grid-like attention
|
||||||
|
window_size = 7, # window size for block and grids
|
||||||
|
mbconv_expansion_rate = 4, # expansion rate of MBConv
|
||||||
|
mbconv_shrinkage_rate = 0.25, # shrinkage rate of squeeze-excitation in MBConv
|
||||||
|
dropout = 0.1 # dropout
|
||||||
|
)
|
||||||
|
|
||||||
|
img = torch.randn(2, 3, 224, 224)
|
||||||
|
|
||||||
|
preds = v(img) # (2, 1000)
|
||||||
|
```
|
||||||
|
|
||||||
## NesT
|
## NesT
|
||||||
|
|
||||||
<img src="./images/nest.png" width="400px"></img>
|
<img src="./images/nest.png" width="400px"></img>
|
||||||
@@ -1544,6 +1576,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@inproceedings{Tu2022MaxViTMV,
|
||||||
|
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||||
|
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
||||||
|
year = {2022}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@misc{vaswani2017attention,
|
@misc{vaswani2017attention,
|
||||||
title = {Attention Is All You Need},
|
title = {Attention Is All You Need},
|
||||||
|
|||||||
BIN
images/max-vit.png
Normal file
BIN
images/max-vit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 133 KiB |
4
setup.py
4
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
|||||||
setup(
|
setup(
|
||||||
name = 'vit-pytorch',
|
name = 'vit-pytorch',
|
||||||
packages = find_packages(exclude=['examples']),
|
packages = find_packages(exclude=['examples']),
|
||||||
version = '0.32.2',
|
version = '0.33.0',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'Vision Transformer (ViT) - Pytorch',
|
description = 'Vision Transformer (ViT) - Pytorch',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
@@ -16,7 +16,7 @@ setup(
|
|||||||
],
|
],
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'einops>=0.4.1',
|
'einops>=0.4.1',
|
||||||
'torch>=1.6',
|
'torch>=1.10',
|
||||||
'torchvision'
|
'torchvision'
|
||||||
],
|
],
|
||||||
setup_requires=[
|
setup_requires=[
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ class Attention(nn.Module):
|
|||||||
# calculate and store indices for retrieving bias
|
# calculate and store indices for retrieving bias
|
||||||
|
|
||||||
pos = torch.arange(window_size)
|
pos = torch.arange(window_size)
|
||||||
grid = torch.stack(torch.meshgrid(pos, pos))
|
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
|
||||||
grid = rearrange(grid, 'c i j -> (i j) c')
|
grid = rearrange(grid, 'c i j -> (i j) c')
|
||||||
rel_pos = grid[:, None] - grid[None, :]
|
rel_pos = grid[:, None] - grid[None, :]
|
||||||
rel_pos += window_size - 1
|
rel_pos += window_size - 1
|
||||||
@@ -144,7 +144,7 @@ class Attention(nn.Module):
|
|||||||
# add dynamic positional bias
|
# add dynamic positional bias
|
||||||
|
|
||||||
pos = torch.arange(-wsz, wsz + 1, device = device)
|
pos = torch.arange(-wsz, wsz + 1, device = device)
|
||||||
rel_pos = torch.stack(torch.meshgrid(pos, pos))
|
rel_pos = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
|
||||||
rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
|
rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
|
||||||
biases = self.dpb(rel_pos.float())
|
biases = self.dpb(rel_pos.float())
|
||||||
rel_pos_bias = biases[self.rel_pos_indices]
|
rel_pos_bias = biases[self.rel_pos_indices]
|
||||||
|
|||||||
@@ -71,8 +71,8 @@ class Attention(nn.Module):
|
|||||||
q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
|
q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
|
||||||
k_range = torch.arange(fmap_size)
|
k_range = torch.arange(fmap_size)
|
||||||
|
|
||||||
q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
|
q_pos = torch.stack(torch.meshgrid(q_range, q_range, indexing = 'ij'), dim = -1)
|
||||||
k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)
|
k_pos = torch.stack(torch.meshgrid(k_range, k_range, indexing = 'ij'), dim = -1)
|
||||||
|
|
||||||
q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
|
q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
|
||||||
rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
|
rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
|
||||||
|
|||||||
270
vit_pytorch/max_vit.py
Normal file
270
vit_pytorch/max_vit.py
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn, einsum
|
||||||
|
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from einops.layers.torch import Rearrange, Reduce
|
||||||
|
|
||||||
|
# helpers
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if exists(val) else d
|
||||||
|
|
||||||
|
def cast_tuple(val, length = 1):
|
||||||
|
return val if isinstance(val, tuple) else ((val,) * length)
|
||||||
|
|
||||||
|
# helper classes
|
||||||
|
|
||||||
|
class PreNormResidual(nn.Module):
|
||||||
|
def __init__(self, dim, fn):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.fn(self.norm(x)) + x
|
||||||
|
|
||||||
|
# MBConv
|
||||||
|
|
||||||
|
class SqueezeExcitation(nn.Module):
|
||||||
|
def __init__(self, dim, shrinkage_rate = 0.25):
|
||||||
|
super().__init__()
|
||||||
|
hidden_dim = int(dim * shrinkage_rate)
|
||||||
|
|
||||||
|
self.gate = nn.Sequential(
|
||||||
|
Reduce('b c h w -> b c', 'mean'),
|
||||||
|
nn.Linear(dim, hidden_dim, bias = False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(hidden_dim, dim, bias = False),
|
||||||
|
nn.Sigmoid(),
|
||||||
|
Rearrange('b c -> b c 1 1')
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * self.gate(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MBConvResidual(nn.Module):
|
||||||
|
def __init__(self, fn, dropout = 0.):
|
||||||
|
super().__init__()
|
||||||
|
self.fn = fn
|
||||||
|
self.dropsample = Dropsample(dropout)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.fn(x)
|
||||||
|
out = self.dropsample(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class Dropsample(nn.Module):
|
||||||
|
def __init__(self, prob = 0):
|
||||||
|
super().__init__()
|
||||||
|
self.prob = prob
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
device = x.device
|
||||||
|
|
||||||
|
if self.prob == 0. or (not self.training):
|
||||||
|
return x
|
||||||
|
|
||||||
|
keep_mask = torch.FloatTensor((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
|
||||||
|
return x * keep_mask / (1 - self.prob)
|
||||||
|
|
||||||
|
def MBConv(
|
||||||
|
dim_in,
|
||||||
|
dim_out,
|
||||||
|
*,
|
||||||
|
downsample,
|
||||||
|
expansion_rate = 4,
|
||||||
|
shrinkage_rate = 0.25,
|
||||||
|
dropout = 0.
|
||||||
|
):
|
||||||
|
hidden_dim = int(expansion_rate * dim_out)
|
||||||
|
stride = 2 if downsample else 1
|
||||||
|
|
||||||
|
net = nn.Sequential(
|
||||||
|
nn.Conv2d(dim_in, dim_out, 1),
|
||||||
|
nn.BatchNorm2d(dim_out),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Conv2d(dim_out, dim_out, 3, stride = stride, padding = 1, groups = dim_out),
|
||||||
|
SqueezeExcitation(dim_out, shrinkage_rate = shrinkage_rate),
|
||||||
|
nn.Conv2d(dim_out, dim_out, 1),
|
||||||
|
nn.BatchNorm2d(dim_out)
|
||||||
|
)
|
||||||
|
|
||||||
|
if dim_in == dim_out and not downsample:
|
||||||
|
net = MBConvResidual(net, dropout = dropout)
|
||||||
|
|
||||||
|
return net
|
||||||
|
|
||||||
|
# attention related classes
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
dim_head = 32,
|
||||||
|
dropout = 0.,
|
||||||
|
window_size = 7
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
|
||||||
|
|
||||||
|
self.heads = dim // dim_head
|
||||||
|
self.scale = dim_head ** -0.5
|
||||||
|
|
||||||
|
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
|
||||||
|
|
||||||
|
self.attend = nn.Sequential(
|
||||||
|
nn.Softmax(dim = -1),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
nn.Linear(dim, dim, bias = False),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
|
# relative positional bias
|
||||||
|
|
||||||
|
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
|
||||||
|
|
||||||
|
pos = torch.arange(window_size)
|
||||||
|
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
|
||||||
|
grid = rearrange(grid, 'c i j -> (i j) c')
|
||||||
|
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
|
||||||
|
rel_pos += window_size - 1
|
||||||
|
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
|
||||||
|
|
||||||
|
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
|
||||||
|
|
||||||
|
# flatten
|
||||||
|
|
||||||
|
x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
|
||||||
|
|
||||||
|
# project for queries, keys, values
|
||||||
|
|
||||||
|
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
|
||||||
|
|
||||||
|
# split heads
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d ) -> b h n d', h = h), (q, k, v))
|
||||||
|
|
||||||
|
# scale
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
|
||||||
|
# sim
|
||||||
|
|
||||||
|
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||||
|
|
||||||
|
# add positional bias
|
||||||
|
|
||||||
|
bias = self.rel_pos_bias(self.rel_pos_indices)
|
||||||
|
sim = sim + rearrange(bias, 'i j h -> h i j')
|
||||||
|
|
||||||
|
# attention
|
||||||
|
|
||||||
|
attn = self.attend(sim)
|
||||||
|
|
||||||
|
# aggregate
|
||||||
|
|
||||||
|
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||||
|
|
||||||
|
# merge heads
|
||||||
|
|
||||||
|
out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)
|
||||||
|
|
||||||
|
# combine heads out
|
||||||
|
|
||||||
|
out = self.to_out(out)
|
||||||
|
return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)
|
||||||
|
|
||||||
|
class MaxViT(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
num_classes,
|
||||||
|
dim,
|
||||||
|
depth,
|
||||||
|
dim_head = 32,
|
||||||
|
dim_conv_stem = None,
|
||||||
|
window_size = 7,
|
||||||
|
mbconv_expansion_rate = 4,
|
||||||
|
mbconv_shrinkage_rate = 0.25,
|
||||||
|
dropout = 0.1,
|
||||||
|
channels = 3
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
|
||||||
|
|
||||||
|
# convolutional stem
|
||||||
|
|
||||||
|
dim_conv_stem = default(dim_conv_stem, dim)
|
||||||
|
|
||||||
|
self.conv_stem = nn.Sequential(
|
||||||
|
nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
|
||||||
|
nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# variables
|
||||||
|
|
||||||
|
num_stages = len(depth)
|
||||||
|
|
||||||
|
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
|
||||||
|
dims = (dim_conv_stem, *dims)
|
||||||
|
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
|
||||||
|
# shorthand for window size for efficient block - grid like attention
|
||||||
|
|
||||||
|
w = window_size
|
||||||
|
|
||||||
|
# iterate through stages
|
||||||
|
|
||||||
|
for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
|
||||||
|
for stage_ind in range(layer_depth):
|
||||||
|
is_first = stage_ind == 0
|
||||||
|
stage_dim_in = layer_dim_in if is_first else layer_dim
|
||||||
|
|
||||||
|
block = nn.Sequential(
|
||||||
|
MBConv(
|
||||||
|
stage_dim_in,
|
||||||
|
layer_dim,
|
||||||
|
downsample = is_first,
|
||||||
|
expansion_rate = mbconv_expansion_rate,
|
||||||
|
shrinkage_rate = mbconv_shrinkage_rate
|
||||||
|
),
|
||||||
|
Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
|
||||||
|
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
|
||||||
|
Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
|
||||||
|
|
||||||
|
Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
|
||||||
|
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
|
||||||
|
Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers.append(block)
|
||||||
|
|
||||||
|
# mlp head out
|
||||||
|
|
||||||
|
self.mlp_head = nn.Sequential(
|
||||||
|
Reduce('b d h w -> b d', 'mean'),
|
||||||
|
nn.LayerNorm(dims[-1]),
|
||||||
|
nn.Linear(dims[-1], num_classes)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv_stem(x)
|
||||||
|
|
||||||
|
for stage in self.layers:
|
||||||
|
x = stage(x)
|
||||||
|
|
||||||
|
return self.mlp_head(x)
|
||||||
@@ -138,7 +138,7 @@ class R2LTransformer(nn.Module):
|
|||||||
h_range = torch.arange(window_size_h, device = device)
|
h_range = torch.arange(window_size_h, device = device)
|
||||||
w_range = torch.arange(window_size_w, device = device)
|
w_range = torch.arange(window_size_w, device = device)
|
||||||
|
|
||||||
grid_x, grid_y = torch.meshgrid(h_range, w_range)
|
grid_x, grid_y = torch.meshgrid(h_range, w_range, indexing = 'ij')
|
||||||
grid = torch.stack((grid_x, grid_y))
|
grid = torch.stack((grid_x, grid_y))
|
||||||
grid = rearrange(grid, 'c h w -> c (h w)')
|
grid = rearrange(grid, 'c h w -> c (h w)')
|
||||||
grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)
|
grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user