diff --git a/README.md b/README.md index 0a5a227..aaa2432 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,31 @@ v = v.to_vit() type(v) # ``` +## Token-to-Token ViT + + + +This paper proposes that the first couple layers should downsample the image sequence by unfolding, leading to overlapping image data in each token as shown in the figure above. You can use this variant of the `ViT` as follows. + +```python +import torch +from vit_pytorch.t2t import T2TViT + +v = T2TViT( + dim = 512, + image_size = 224, + patch_size = 16, + depth = 5, + heads = 8, + mlp_dim = 512, + num_classes = 1000, + t2t_layers = ((7, 4), (3, 2), (3, 2)) # tuples of the kernel size and stride of each consecutive layers of the initial token to token module +) + +img = torch.randn(1, 3, 224, 224) +v(img) # (1, 1000) +``` + ## Research Ideas ### Self Supervised Training @@ -273,6 +298,17 @@ Coming from computer vision and new to transformers? Here are some resources tha } ``` +```bibtex +@misc{yuan2021tokenstotoken, + title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet}, + author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan}, + year = {2021}, + eprint = {2101.11986}, + archivePrefix = {arXiv}, + primaryClass = {cs.CV} +} +``` + ```bibtex @misc{vaswani2017attention, title = {Attention Is All You Need}, diff --git a/setup.py b/setup.py index 6a58727..e7607ad 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.6.8', + version = '0.7.0', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', diff --git a/t2t.png b/t2t.png new file mode 100644 index 0000000..8277839 Binary files /dev/null and b/t2t.png differ diff --git a/vit_pytorch/t2t.py b/vit_pytorch/t2t.py new file mode 100644 index 0000000..84b39ae --- /dev/null +++ b/vit_pytorch/t2t.py @@ -0,0 +1,72 @@ +import math +import torch +from torch import nn + +from vit_pytorch.vit_pytorch import Transformer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +# classes + +class RearrangeImage(nn.Module): + def forward(self, x): + return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1]))) + +# main class + +class T2TViT(nn.Module): + def __init__( + self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))): + super().__init__() + assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' + num_patches = (image_size // patch_size) ** 2 + patch_dim = channels * patch_size ** 2 + assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' + + layers = [] + layer_dim = channels + + for i, (kernel_size, stride) in enumerate(t2t_layers): + layer_dim *= kernel_size ** 2 + is_first = i == 0 + + layers.extend([ + RearrangeImage() if not is_first else nn.Identity(), + nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2), + Rearrange('b c n -> b n c'), + Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout), + ]) + + layers.append(nn.Linear(layer_dim, dim)) + self.to_patch_embedding = nn.Sequential(*layers) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) + + self.pool = pool + self.to_latent = nn.Identity() + + self.mlp_head = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, num_classes) + ) + + def forward(self, img): + x = self.to_patch_embedding(img) + b, n, _ = x.shape + + cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) + x = torch.cat((cls_tokens, x), dim=1) + x += self.pos_embedding[:, :(n + 1)] + x = self.dropout(x) + + x = self.transformer(x) + + x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] + + x = self.to_latent(x) + return self.mlp_head(x) diff --git a/vit_pytorch/vit_pytorch.py b/vit_pytorch/vit_pytorch.py index a797068..56a5f8a 100644 --- a/vit_pytorch/vit_pytorch.py +++ b/vit_pytorch/vit_pytorch.py @@ -1,7 +1,10 @@ +import math import torch +from torch import nn, einsum import torch.nn.functional as F + from einops import rearrange, repeat -from torch import nn +from einops.layers.torch import Rearrange MIN_NUM_PATCHES = 16 @@ -51,7 +54,7 @@ class Attention(nn.Module): qkv = self.to_qkv(x).chunk(3, dim = -1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) - dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale mask_value = -torch.finfo(dots.dtype).max if mask is not None: @@ -63,13 +66,13 @@ class Attention(nn.Module): attn = dots.softmax(dim=-1) - out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) + out = einsum('b h i j, b h j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') out = self.to_out(out) return out class Transformer(nn.Module): - def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): @@ -92,10 +95,12 @@ class ViT(nn.Module): assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size' assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' - self.patch_size = patch_size + self.to_patch_embedding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), + nn.Linear(patch_dim, dim), + ) self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) - self.patch_to_embedding = nn.Linear(patch_dim, dim) self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) self.dropout = nn.Dropout(emb_dropout) @@ -110,10 +115,7 @@ class ViT(nn.Module): ) def forward(self, img, mask = None): - p = self.patch_size - - x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) - x = self.patch_to_embedding(x) + x = self.to_patch_embedding(img) b, n, _ = x.shape cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)