Compare commits

...

5 Commits

Author SHA1 Message Date
Phil Wang
1616288e30 add xcit (#284)
* add xcit

* use Rearrange layers

* give cross correlation transformer a final norm at end

* document
2023-10-13 09:15:13 -07:00
Jason Chou
9e1e824385 Update README.md (#283)
`patch_size` is size of patches, not number of patches
2023-10-09 11:33:56 -07:00
lucidrains
bbb24e34d4 give a learned bias to and from registers for maxvit + register token variant 2023-10-06 10:40:26 -07:00
lucidrains
df8733d86e improvise a max vit with register tokens 2023-10-06 10:27:36 -07:00
lucidrains
680d446e46 document in readme later 2023-10-03 09:26:02 -07:00
7 changed files with 676 additions and 5 deletions

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: [3.8, 3.9]
steps:
- uses: actions/checkout@v2

View File

@@ -25,6 +25,7 @@
- [MaxViT](#maxvit)
- [NesT](#nest)
- [MobileViT](#mobilevit)
- [XCiT](#xcit)
- [Masked Autoencoder](#masked-autoencoder)
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
@@ -92,7 +93,7 @@ preds = v(img) # (1, 1000)
- `image_size`: int.
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
- `patch_size`: int.
Number of patches. `image_size` must be divisible by `patch_size`.
Size of patches. `image_size` must be divisible by `patch_size`.
The number of patches is: ` n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
- `num_classes`: int.
Number of classes to classify.
@@ -772,6 +773,38 @@ img = torch.randn(1, 3, 256, 256)
pred = mbvit_xs(img) # (1, 1000)
```
## XCiT
<img src="./images/xcit.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2106.09681">paper</a> introduces the cross correlation attention (abbreviated XCA). One can think of it as doing attention across the features dimension rather than the spatial one (another perspective would be a dynamic 1x1 convolution, the kernel being attention map defined by spatial correlations).
Technically, this amounts to simply transposing the query, key, values before executing cosine similarity attention with learned temperature.
```python
import torch
from vit_pytorch.xcit import XCiT
v = XCiT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 12, # depth of xcit transformer
cls_depth = 2, # depth of cross attention of CLS tokens to patch, attention pool at end
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1,
layer_dropout = 0.05, # randomly dropout 5% of the layers
local_patch_kernel_size = 3 # kernel size of the local patch interaction module (depthwise convs)
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
```
## Simple Masked Image Modeling
<img src="./images/simmim.png" width="400px"/>
@@ -2029,4 +2062,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@inproceedings{ElNouby2021XCiTCI,
title = {XCiT: Cross-Covariance Image Transformers},
author = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
booktitle = {Neural Information Processing Systems},
year = {2021},
url = {https://api.semanticscholar.org/CorpusID:235458262}
}
```
*I visualise a time when we will be to robots what dogs are to humans, and Im rooting for the machines.* — Claude Shannon

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.5.0',
version = '1.6.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
@@ -16,7 +16,7 @@ setup(
'image recognition'
],
install_requires=[
'einops>=0.6.1',
'einops>=0.7.0',
'torch>=1.10',
'torchvision'
],

View File

@@ -173,7 +173,7 @@ class Attention(nn.Module):
# split heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d ) -> b h n d', h = h), (q, k, v))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# scale

View File

@@ -0,0 +1,340 @@
from functools import partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Sequential
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def pack_one(x, pattern):
return pack([x], pattern)
def unpack_one(x, ps, pattern):
return unpack(x, ps, pattern)[0]
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
# helper classes
def FeedForward(dim, mult = 4, dropout = 0.):
inner_dim = int(dim * mult)
return Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
# MBConv
class SqueezeExcitation(Module):
def __init__(self, dim, shrinkage_rate = 0.25):
super().__init__()
hidden_dim = int(dim * shrinkage_rate)
self.gate = Sequential(
Reduce('b c h w -> b c', 'mean'),
nn.Linear(dim, hidden_dim, bias = False),
nn.SiLU(),
nn.Linear(hidden_dim, dim, bias = False),
nn.Sigmoid(),
Rearrange('b c -> b c 1 1')
)
def forward(self, x):
return x * self.gate(x)
class MBConvResidual(Module):
def __init__(self, fn, dropout = 0.):
super().__init__()
self.fn = fn
self.dropsample = Dropsample(dropout)
def forward(self, x):
out = self.fn(x)
out = self.dropsample(out)
return out + x
class Dropsample(Module):
def __init__(self, prob = 0):
super().__init__()
self.prob = prob
def forward(self, x):
device = x.device
if self.prob == 0. or (not self.training):
return x
keep_mask = torch.FloatTensor((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
return x * keep_mask / (1 - self.prob)
def MBConv(
dim_in,
dim_out,
*,
downsample,
expansion_rate = 4,
shrinkage_rate = 0.25,
dropout = 0.
):
hidden_dim = int(expansion_rate * dim_out)
stride = 2 if downsample else 1
net = Sequential(
nn.Conv2d(dim_in, hidden_dim, 1),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
nn.Conv2d(hidden_dim, dim_out, 1),
nn.BatchNorm2d(dim_out)
)
if dim_in == dim_out and not downsample:
net = MBConvResidual(net, dropout = dropout)
return net
# attention related classes
class Attention(Module):
def __init__(
self,
dim,
dim_head = 32,
dropout = 0.,
window_size = 7,
num_registers = 1
):
super().__init__()
assert num_registers > 0
assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
self.heads = dim // dim_head
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
)
self.to_out = nn.Sequential(
nn.Linear(dim, dim, bias = False),
nn.Dropout(dropout)
)
# relative positional bias
num_rel_pos_bias = (2 * window_size - 1) ** 2
self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
def forward(self, x):
device, h, bias_indices = x.device, self.heads, self.rel_pos_indices
x = self.norm(x)
# project for queries, keys, values
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
# split heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# scale
q = q * self.scale
# sim
sim = einsum('b h i d, b h j d -> b h i j', q, k)
# add positional bias
bias = self.rel_pos_bias(bias_indices)
sim = sim + rearrange(bias, 'i j h -> h i j')
# attention
attn = self.attend(sim)
# aggregate
out = einsum('b h i j, b h j d -> b h i d', attn, v)
# combine heads out
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class MaxViT(Module):
def __init__(
self,
*,
num_classes,
dim,
depth,
dim_head = 32,
dim_conv_stem = None,
window_size = 7,
mbconv_expansion_rate = 4,
mbconv_shrinkage_rate = 0.25,
dropout = 0.1,
channels = 3,
num_register_tokens = 4
):
super().__init__()
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
assert num_register_tokens > 0
# convolutional stem
dim_conv_stem = default(dim_conv_stem, dim)
self.conv_stem = Sequential(
nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
)
# variables
num_stages = len(depth)
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
dims = (dim_conv_stem, *dims)
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
self.layers = nn.ModuleList([])
# window size
self.window_size = window_size
self.register_tokens = nn.ParameterList([])
# iterate through stages
for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
for stage_ind in range(layer_depth):
is_first = stage_ind == 0
stage_dim_in = layer_dim_in if is_first else layer_dim
conv = MBConv(
stage_dim_in,
layer_dim,
downsample = is_first,
expansion_rate = mbconv_expansion_rate,
shrinkage_rate = mbconv_shrinkage_rate
)
block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
block_ff = FeedForward(dim = layer_dim, dropout = dropout)
grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
grid_ff = FeedForward(dim = layer_dim, dropout = dropout)
register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))
self.layers.append(ModuleList([
conv,
ModuleList([block_attn, block_ff]),
ModuleList([grid_attn, grid_ff])
]))
self.register_tokens.append(register_tokens)
# mlp head out
self.mlp_head = nn.Sequential(
Reduce('b d h w -> b d', 'mean'),
nn.LayerNorm(dims[-1]),
nn.Linear(dims[-1], num_classes)
)
def forward(self, x):
b, w = x.shape[0], self.window_size
x = self.conv_stem(x)
for (conv, (block_attn, block_ff), (grid_attn, grid_ff)), register_tokens in zip(self.layers, self.register_tokens):
x = conv(x)
# block-like attention
x = rearrange(x, 'b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w)
# prepare register tokens
r = repeat(register_tokens, 'n d -> b x y n d', b = b, x = x.shape[1],y = x.shape[2])
r, register_batch_ps = pack_one(r, '* n d')
x, window_ps = pack_one(x, 'b x y * d')
x, batch_ps = pack_one(x, '* n d')
x, register_ps = pack([r, x], 'b * d')
x = block_attn(x) + x
x = block_ff(x) + x
r, x = unpack(x, register_ps, 'b * d')
x = unpack_one(x, batch_ps, '* n d')
x = unpack_one(x, window_ps, 'b x y * d')
x = rearrange(x, 'b x y w1 w2 d -> b d (x w1) (y w2)')
r = unpack_one(r, register_batch_ps, '* n d')
# grid-like attention
x = rearrange(x, 'b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w)
# prepare register tokens
r = reduce(r, 'b x y n d -> b n d', 'mean')
r = repeat(r, 'b n d -> b x y n d', x = x.shape[1], y = x.shape[2])
r, register_batch_ps = pack_one(r, '* n d')
x, window_ps = pack_one(x, 'b x y * d')
x, batch_ps = pack_one(x, '* n d')
x, register_ps = pack([r, x], 'b * d')
x = grid_attn(x) + x
r, x = unpack(x, register_ps, 'b * d')
x = grid_ff(x) + x
x = unpack_one(x, batch_ps, '* n d')
x = unpack_one(x, window_ps, 'b x y * d')
x = rearrange(x, 'b x y w1 w2 d -> b d (w1 x) (w2 y)')
return self.mlp_head(x)

View File

@@ -1,3 +1,8 @@
"""
Vision Transformers Need Registers
https://arxiv.org/abs/2309.16588
"""
import torch
from torch import nn

283
vit_pytorch/xcit.py Normal file
View File

@@ -0,0 +1,283 @@
from random import randrange
import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
def dropout_layers(layers, dropout):
if dropout == 0:
return layers
num_layers = len(layers)
to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout
# make sure at least one layer makes it
if all(to_drop):
rand_index = randrange(num_layers)
to_drop[rand_index] = False
layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
return layers
# classes
class LayerScale(Module):
def __init__(self, dim, fn, depth):
super().__init__()
if depth <= 18:
init_eps = 0.1
elif 18 > depth <= 24:
init_eps = 1e-5
else:
init_eps = 1e-6
self.fn = fn
self.scale = nn.Parameter(torch.full((dim,), init_eps))
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, context = None):
h = self.heads
x = self.norm(x)
context = x if not exists(context) else torch.cat((x, context), dim = 1)
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(sim)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class XCAttention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
h = self.heads
x, ps = pack_one(x, 'b * d')
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h d n', h = h), (q, k, v))
q, k = map(l2norm, (q, k))
sim = einsum('b h i n, b h j n -> b h i j', q, k) * self.temperature.exp()
attn = self.attend(sim)
attn = self.dropout(attn)
out = einsum('b h i j, b h j n -> b h i n', attn, v)
out = rearrange(out, 'b h d n -> b n (h d)')
out = unpack_one(out, ps, 'b * d')
return self.to_out(out)
class LocalPatchInteraction(Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
assert (kernel_size % 2) == 1
padding = kernel_size // 2
self.net = nn.Sequential(
nn.LayerNorm(dim),
Rearrange('b h w c -> b c h w'),
nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
nn.BatchNorm2d(dim),
nn.GELU(),
nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
Rearrange('b c h w -> b h w c'),
)
def forward(self, x):
return self.net(x)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
super().__init__()
self.layers = ModuleList([])
self.layer_dropout = layer_dropout
for ind in range(depth):
layer = ind + 1
self.layers.append(ModuleList([
LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
]))
def forward(self, x, context = None):
layers = dropout_layers(self.layers, dropout = self.layer_dropout)
for attn, ff in layers:
x = attn(x, context = context) + x
x = ff(x) + x
return x
class XCATransformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size = 3, dropout = 0., layer_dropout = 0.):
super().__init__()
self.layers = ModuleList([])
self.layer_dropout = layer_dropout
for ind in range(depth):
layer = ind + 1
self.layers.append(ModuleList([
LayerScale(dim, XCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
LayerScale(dim, LocalPatchInteraction(dim, local_patch_kernel_size), depth = layer),
LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
]))
def forward(self, x):
layers = dropout_layers(self.layers, dropout = self.layer_dropout)
for cross_covariance_attn, local_patch_interaction, ff in layers:
x = cross_covariance_attn(x) + x
x = local_patch_interaction(x) + x
x = ff(x) + x
return x
class XCiT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
cls_depth,
heads,
mlp_dim,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
local_patch_kernel_size = 3,
layer_dropout = 0.
):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = 3 * patch_size ** 2
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
self.cls_token = nn.Parameter(torch.randn(dim))
self.dropout = nn.Dropout(emb_dropout)
self.xcit_transformer = XCATransformer(dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size, dropout, layer_dropout)
self.final_norm = nn.LayerNorm(dim)
self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
x, ps = pack_one(x, 'b * d')
b, n, _ = x.shape
x += self.pos_embedding[:, :n]
x = unpack_one(x, ps, 'b * d')
x = self.dropout(x)
x = self.xcit_transformer(x)
x = self.final_norm(x)
cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)
x = rearrange(x, 'b ... d -> b (...) d')
cls_tokens = self.cls_transformer(cls_tokens, context = x)
return self.mlp_head(cls_tokens[:, 0])