diff --git a/README.md b/README.md
index 965bdcb..89e7a9c 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@
 - [MaxViT](#maxvit)
 - [NesT](#nest)
 - [MobileViT](#mobilevit)
+- [XCiT](#xcit)
 - [Masked Autoencoder](#masked-autoencoder)
 - [Simple Masked Image Modeling](#simple-masked-image-modeling)
 - [Masked Patch Prediction](#masked-patch-prediction)
@@ -772,6 +773,38 @@ img = torch.randn(1, 3, 256, 256)
 
 pred = mbvit_xs(img) # (1, 1000)
 ```
 
+## XCiT
+
+<img src="./images/xcit.png" width="400px"></img>
+
+This paper introduces the cross covariance attention (abbreviated XCA). One can think of it as doing attention across the features dimension rather than the spatial one (another perspective would be a dynamic 1x1 convolution, the kernel being the attention map defined by spatial correlations).
+
+Technically, this amounts to simply transposing the queries, keys and values before executing cosine similarity attention with a learned temperature.
+
+```python
+import torch
+from vit_pytorch.xcit import XCiT
+
+v = XCiT(
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    dim = 1024,
+    depth = 12,                   # depth of xcit transformer
+    cls_depth = 2,                # depth of cross attention of CLS tokens to patch, attention pool at end
+    heads = 16,
+    mlp_dim = 2048,
+    dropout = 0.1,
+    emb_dropout = 0.1,
+    layer_dropout = 0.05,         # randomly dropout 5% of the layers
+    local_patch_kernel_size = 3   # kernel size of the local patch interaction module (depthwise convs)
+)
+
+img = torch.randn(1, 3, 256, 256)
+
+preds = v(img) # (1, 1000)
+```
+
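+For intuition, here is a minimal sketch of that transposed attention step. It is illustrative only (the shapes and tensor names below are arbitrary stand-ins, not the `XCAttention` module itself), but it mirrors what `XCAttention` in `vit_pytorch/xcit.py` does per head:
+
+```python
+import torch
+import torch.nn.functional as F
+
+b, h, n, d = 1, 16, 64, 64                             # batch, heads, tokens, head dimension
+q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
+temperature = torch.ones(h, 1, 1)                      # stands in for XCAttention's learned per-head temperature
+
+# transpose so the feature axis takes the place of the token axis
+q, k, v = (t.transpose(-1, -2) for t in (q, k, v))     # (b, h, d, n)
+
+q, k = (F.normalize(t, dim = -1) for t in (q, k))      # l2 norm, so the dot product is a cosine similarity
+attn = (q @ k.transpose(-1, -2) * temperature.exp()).softmax(dim = -1)   # (b, h, d, d)
+
+out = (attn @ v).transpose(-1, -2)                     # back to (b, h, n, d)
+```
+
+The resulting `d x d` attention map re-mixes feature channels identically at every spatial position, which is the dynamic 1x1 convolution view mentioned above.
+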
 ## Simple Masked Image Modeling
@@ -2029,4 +2062,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
 }
 ```
 
+```bibtex
+@inproceedings{ElNouby2021XCiTCI,
+    title     = {XCiT: Cross-Covariance Image Transformers},
+    author    = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
+    booktitle = {Neural Information Processing Systems},
+    year      = {2021},
+    url       = {https://api.semanticscholar.org/CorpusID:235458262}
+}
+```
+
 *I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
diff --git a/setup.py b/setup.py
index 929b58a..a836cbf 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.5.3',
+  version = '1.6.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/xcit.py b/vit_pytorch/xcit.py
new file mode 100644
index 0000000..83b098f
--- /dev/null
+++ b/vit_pytorch/xcit.py
@@ -0,0 +1,283 @@
+from random import randrange
+
+import torch
+from torch import nn, einsum
+from torch.nn import Module, ModuleList
+import torch.nn.functional as F
+
+from einops import rearrange, repeat, pack, unpack
+from einops.layers.torch import Rearrange
+
+# helpers
+
+def exists(val):
+    return val is not None
+
+def pack_one(t, pattern):
+    return pack([t], pattern)
+
+def unpack_one(t, ps, pattern):
+    return unpack(t, ps, pattern)[0]
+
+def l2norm(t):
+    return F.normalize(t, dim = -1, p = 2)
+
+def dropout_layers(layers, dropout):
+    if dropout == 0:
+        return layers
+
+    num_layers = len(layers)
+    to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout
+
+    # make sure at least one layer makes it
+    if all(to_drop):
+        rand_index = randrange(num_layers)
+        to_drop[rand_index] = False
+
+    layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
+    return layers
+
+# classes
+
+class LayerScale(Module):
+    def __init__(self, dim, fn, depth):
+        super().__init__()
+        if depth <= 18:
+            init_eps = 0.1
+        elif depth > 18 and depth <= 24:
+            init_eps = 1e-5
+        else:
+            init_eps = 1e-6
+
+        self.fn = fn
+        self.scale = nn.Parameter(torch.full((dim,), init_eps))
+
+    def forward(self, x, **kwargs):
+        return self.fn(x, **kwargs) * self.scale
+
+class FeedForward(Module):
+    def __init__(self, dim, hidden_dim, dropout = 0.):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout)
+        )
+    def forward(self, x):
+        return self.net(x)
+
+class Attention(Module):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+
+        self.norm = nn.LayerNorm(dim)
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
+
+        self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x, context = None):
+        h = self.heads
+
+        x = self.norm(x)
+        context = x if not exists(context) else torch.cat((x, context), dim = 1)
+
+        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+
+        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+
+        attn = self.attend(sim)
+        attn = self.dropout(attn)
+
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
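+# illustrative shape walkthrough for XCAttention below:
+#   x is packed to (b, n, dim) and projected to q, k, v of shape (b, heads, dim_head, n)
+#   q and k are l2-normalized along the token axis n, so the first einsum yields a per-head
+#   (dim_head x dim_head) cosine-similarity map, scaled by exp(temperature)
+#   that map then mixes the feature channels of v identically at every token position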
+class XCAttention(Module):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        self.heads = heads
+        self.norm = nn.LayerNorm(dim)
+
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+
+        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
+
+        self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        h = self.heads
+        x, ps = pack_one(x, 'b * d')
+
+        x = self.norm(x)
+        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h d n', h = h), (q, k, v))
+
+        q, k = map(l2norm, (q, k))
+
+        sim = einsum('b h i n, b h j n -> b h i j', q, k) * self.temperature.exp()
+
+        attn = self.attend(sim)
+        attn = self.dropout(attn)
+
+        out = einsum('b h i j, b h j n -> b h i n', attn, v)
+        out = rearrange(out, 'b h d n -> b n (h d)')
+
+        out = unpack_one(out, ps, 'b * d')
+        return self.to_out(out)
+
+class LocalPatchInteraction(Module):
+    def __init__(self, dim, kernel_size = 3):
+        super().__init__()
+        assert (kernel_size % 2) == 1
+        padding = kernel_size // 2
+
+        self.net = nn.Sequential(
+            nn.LayerNorm(dim),
+            Rearrange('b h w c -> b c h w'),
+            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
+            nn.BatchNorm2d(dim),
+            nn.GELU(),
+            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
+            Rearrange('b c h w -> b h w c'),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+class Transformer(Module):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
+        super().__init__()
+        self.layers = ModuleList([])
+        self.layer_dropout = layer_dropout
+
+        for ind in range(depth):
+            layer = ind + 1
+            self.layers.append(ModuleList([
+                LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
+                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
+            ]))
+
+    def forward(self, x, context = None):
+        layers = dropout_layers(self.layers, dropout = self.layer_dropout)
+
+        for attn, ff in layers:
+            x = attn(x, context = context) + x
+            x = ff(x) + x
+
+        return x
+
+class XCATransformer(Module):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size = 3, dropout = 0., layer_dropout = 0.):
+        super().__init__()
+        self.layers = ModuleList([])
+        self.layer_dropout = layer_dropout
+
+        for ind in range(depth):
+            layer = ind + 1
+            self.layers.append(ModuleList([
+                LayerScale(dim, XCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
+                LayerScale(dim, LocalPatchInteraction(dim, local_patch_kernel_size), depth = layer),
+                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
+            ]))
+
+    def forward(self, x):
+        layers = dropout_layers(self.layers, dropout = self.layer_dropout)
+
+        for cross_covariance_attn, local_patch_interaction, ff in layers:
+            x = cross_covariance_attn(x) + x
+            x = local_patch_interaction(x) + x
+            x = ff(x) + x
+
+        return x
+
+class XCiT(Module):
+    def __init__(
+        self,
+        *,
+        image_size,
+        patch_size,
+        num_classes,
+        dim,
+        depth,
+        cls_depth,
+        heads,
+        mlp_dim,
+        dim_head = 64,
+        dropout = 0.,
+        emb_dropout = 0.,
+        local_patch_kernel_size = 3,
+        layer_dropout = 0.
+    ):
+        super().__init__()
+        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
+
+        num_patches = (image_size // patch_size) ** 2
+        patch_dim = 3 * patch_size ** 2
+
+        self.to_patch_embedding = nn.Sequential(
+            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+            nn.LayerNorm(patch_dim),
+            nn.Linear(patch_dim, dim),
+            nn.LayerNorm(dim)
+        )
+
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
+        self.cls_token = nn.Parameter(torch.randn(dim))
+
+        self.dropout = nn.Dropout(emb_dropout)
+
+        self.xcit_transformer = XCATransformer(dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size, dropout, layer_dropout)
+
+        self.final_norm = nn.LayerNorm(dim)
+
+        self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
+
+        self.mlp_head = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, num_classes)
+        )
+
+    def forward(self, img):
+        x = self.to_patch_embedding(img)
+
+        x, ps = pack_one(x, 'b * d')
+
+        b, n, _ = x.shape
+        x += self.pos_embedding[:, :n]
+
+        x = unpack_one(x, ps, 'b * d')
+
+        x = self.dropout(x)
+
+        x = self.xcit_transformer(x)
+
+        x = self.final_norm(x)
+
+        cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)
+
+        x = rearrange(x, 'b ... d -> b (...) d')
+        cls_tokens = self.cls_transformer(cls_tokens, context = x)
+
+        return self.mlp_head(cls_tokens[:, 0])
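As a hypothetical sanity check (not part of this commit), the snippet below exercises the two stages that `XCiT.forward` wires together: the XCA transformer over the patch grid, then a small transformer in which the CLS token cross-attends to the patch tokens. Only the class names come from the new `vit_pytorch/xcit.py`; every dimension below is arbitrary.

```python
import torch
from vit_pytorch.xcit import XCATransformer, Transformer

# stage 1: cross covariance attention + local patch interaction over the (h, w) patch grid
xca = XCATransformer(dim = 128, depth = 2, heads = 4, dim_head = 32, mlp_dim = 256)

# stage 2: the CLS token cross-attends to the flattened patch tokens (attention pooling)
cls_attn = Transformer(dim = 128, depth = 1, heads = 4, dim_head = 32, mlp_dim = 256)

tokens = torch.randn(2, 8, 8, 128)                      # (batch, grid height, grid width, dim)
tokens = xca(tokens)                                    # grid shape is preserved: (2, 8, 8, 128)

cls = torch.randn(2, 1, 128)                            # stand-in for the learned CLS token
pooled = cls_attn(cls, context = tokens.flatten(1, 2))  # context is (2, 64, 128)

assert pooled.shape == (2, 1, 128)                      # XCiT's mlp_head then consumes pooled[:, 0]
```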