diff --git a/README.md b/README.md
index ada0f53..2278f78 100644
--- a/README.md
+++ b/README.md
@@ -443,6 +443,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
 }
 ```
 
+```bibtex
+@misc{wu2021cvt,
+    title = {CvT: Introducing Convolutions to Vision Transformers},
+    author = {Haiping Wu and Bin Xiao and Noel Codella and Mengchen Liu and Xiyang Dai and Lu Yuan and Lei Zhang},
+    year = {2021},
+    eprint = {2103.15808},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
+
 ```bibtex
 @misc{vaswani2017attention,
     title = {Attention Is All You Need},
diff --git a/setup.py b/setup.py
index 0999784..827acf0 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.10.3',
+  version = '0.11.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/cvt.py b/vit_pytorch/cvt.py
new file mode 100644
index 0000000..b2a1d05
--- /dev/null
+++ b/vit_pytorch/cvt.py
@@ -0,0 +1,155 @@
+import torch
+from torch import nn, einsum
+
+from einops import rearrange
+from einops.layers.torch import Rearrange
+
+# helper methods
+
+def group_dict_by_key(cond, d):
+    # split dictionary `d` into (matching, non-matching) according to predicate `cond` on keys
+    return_val = [dict(), dict()]
+    for key in d.keys():
+        match = bool(cond(key))
+        ind = int(not match)
+        return_val[ind][key] = d[key]
+    return (*return_val,)
+
+def group_by_key_prefix_and_remove_prefix(prefix, d):
+    # gather all keys beginning with `prefix`, returning them with the prefix stripped
+    kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
+    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+    return kwargs_without_prefix, kwargs
+
+# classes
+
+class PreNorm(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.fn = fn
+    def forward(self, x, **kwargs):
+        # layernorm expects channels last, so move the channel dimension back and forth
+        x = rearrange(x, 'b c h w -> b h w c')
+        x = self.norm(x)
+        x = rearrange(x, 'b h w c -> b c h w')
+        return self.fn(x, **kwargs)
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, mult = 4, dropout = 0.):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv2d(dim, dim * mult, 1),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Conv2d(dim * mult, dim, 1),
+            nn.Dropout(dropout)
+        )
+    def forward(self, x):
+        return self.net(x)
+
+class Attention(nn.Module):
+    def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        padding = proj_kernel // 2
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+
+        self.attend = nn.Softmax(dim = -1)
+
+        # convolutional projections - keys and values are strided to shrink the attended sequence
+        self.to_q = nn.Conv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
+        self.to_kv = nn.Conv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Conv2d(inner_dim, dim, 1),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        shape = x.shape
+        b, n, _, y, h = *shape, self.heads # (batch, channels, height, width), plus number of heads
+        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
+        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
+
+        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+        attn = self.attend(dots)
+
+        out = einsum('b i j, b j d -> b i d', attn, v)
+        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
+        return self.to_out(out)
+
+class Transformer(nn.Module):
+    def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
+                PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
+            ]))
+    def forward(self, x):
+        for attn, ff in self.layers:
+            x = attn(x) + x
+            x = ff(x) + x
+        return x
+
+class CvT(nn.Module):
+    def __init__(
+        self,
+        *,
+        num_classes,
+        s1_emb_dim = 64,
+        s1_emb_kernel = 7,
+        s1_emb_stride = 4,
+        s1_proj_kernel = 3,
+        s1_kv_proj_stride = 2,
+        s1_heads = 1,
+        s1_depth = 1,
+        s1_mlp_mult = 4,
+        s2_emb_dim = 192,
+        s2_emb_kernel = 3,
+        s2_emb_stride = 2,
+        s2_proj_kernel = 3,
+        s2_kv_proj_stride = 2,
+        s2_heads = 3,
+        s2_depth = 2,
+        s2_mlp_mult = 4,
+        s3_emb_dim = 384,
+        s3_emb_kernel = 3,
+        s3_emb_stride = 2,
+        s3_proj_kernel = 3,
+        s3_kv_proj_stride = 2,
+        s3_heads = 4,
+        s3_depth = 10,
+        s3_mlp_mult = 4,
+        dropout = 0.
+    ):
+        super().__init__()
+        kwargs = dict(locals()) # capture all constructor arguments, to be routed to stages by prefix
+
+        dim = 3 # channels of the input image
+        layers = []
+
+        # each stage is a convolutional token embedding followed by a transformer
+        for prefix in ('s1', 's2', 's3'):
+            config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
+
+            layers.append(nn.Sequential(
+                nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
+                Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)
+            ))
+
+            dim = config['emb_dim']
+
+        self.layers = nn.Sequential(
+            *layers,
+            nn.AdaptiveAvgPool2d(1),
+            Rearrange('... () () -> ...'), # (b, dim, 1, 1) -> (b, dim)
+            nn.Linear(dim, num_classes)
+        )
+
+    def forward(self, x):
+        return self.layers(x)
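
For reference, here is how `CvT.__init__` turns its flat keyword arguments into per-stage configs: it snapshots `locals()` and peels off one prefix at a time with the `group_by_key_prefix_and_remove_prefix` helper. A minimal sketch, assuming the helpers are importable from `vit_pytorch.cvt`; the dictionary values are illustrative only:

```python
from vit_pytorch.cvt import group_by_key_prefix_and_remove_prefix

# illustrative subset of what CvT.__init__ captures via locals()
kwargs = {'s1_emb_dim': 64, 's1_depth': 1, 's2_emb_dim': 192, 'dropout': 0.}

config, kwargs = group_by_key_prefix_and_remove_prefix('s1_', kwargs)
# config == {'emb_dim': 64, 'depth': 1}         - stage 1 settings, prefix stripped
# kwargs == {'s2_emb_dim': 192, 'dropout': 0.}  - remainder, handed to the next stage
```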
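
A quick shape sanity check for the convolutional attention block, sketched under the assumption that `Attention` is imported from the new module; the sizes follow from the stride-1 query projection and the strided key/value projection:

```python
import torch
from vit_pytorch.cvt import Attention

attn = Attention(dim = 64, proj_kernel = 3, kv_proj_stride = 2)

x = torch.randn(1, 64, 56, 56)
out = attn(x) # (1, 64, 56, 56)

# queries keep the full 56 x 56 grid, while keys / values are computed
# on a strided 28 x 28 grid, shrinking the attention matrix four-fold
```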
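
Finally, an end-to-end usage sketch in the style of the other models in this repository; the hyperparameters shown are simply the constructor defaults spelled out:

```python
import torch
from vit_pytorch.cvt import CvT

v = CvT(
    num_classes = 1000,
    s1_emb_dim = 64,   # stage 1 dimension - conv embedding downsamples 224 -> 56
    s1_emb_kernel = 7,
    s1_emb_stride = 4,
    s2_emb_dim = 192,  # stage 2 dimension - 56 -> 28 (stride 2 embedding)
    s3_emb_dim = 384,  # stage 3 dimension - 28 -> 14 (stride 2 embedding)
    dropout = 0.
)

img = torch.randn(1, 3, 224, 224)
pred = v(img) # (1, 1000)
```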