diff --git a/README.md b/README.md index 1eb3abb..edb97ee 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ - [Simple Masked Image Modeling](#simple-masked-image-modeling) - [Masked Patch Prediction](#masked-patch-prediction) - [Adaptive Token Sampling](#adaptive-token-sampling) +- [Patch Merger](#patch-merger) - [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets) - [Dino](#dino) - [Accessing Attention](#accessing-attention) @@ -732,12 +733,58 @@ v = ViT( img = torch.randn(4, 3, 256, 256) -preds = v(img) # (1, 1000) +preds = v(img) # (4, 1000) # you can also get a list of the final sampled patch ids # a value of -1 denotes padding -preds, token_ids = v(img, return_sampled_token_ids = True) # (1, 1000), (1, <=8) +preds, token_ids = v(img, return_sampled_token_ids = True) # (4, 1000), (4, <=8) +``` + +## Patch Merger + + + + +This paper proposes a simple module (Patch Merger) for reducing the number of tokens at any layer of a vision transformer without sacrificing performance. + +```python +import torch +from vit_pytorch.vit_with_patch_merger import ViT + +v = ViT( + image_size = 256, + patch_size = 16, + num_classes = 1000, + dim = 1024, + depth = 12, + heads = 8, + patch_merge_layer = 6, # at which transformer layer to do patch merging + patch_merge_num_tokens = 8, # the output number of tokens from the patch merge + mlp_dim = 2048, + dropout = 0.1, + emb_dropout = 0.1 +) + +img = torch.randn(4, 3, 256, 256) + +preds = v(img) # (4, 1000) +``` + +One can also use the `PatchMerger` module by itself + +```python +import torch +from vit_pytorch.vit_with_patch_merger import PatchMerger + +merger = PatchMerger( + dim = 1024, + num_tokens_out = 8 # output number of tokens +) + +features = torch.randn(4, 256, 1024) # (batch, num tokens, dimension) + +out = merger(features) # (4, 8, 1024) ``` ## Vision Transformer for Small Datasets @@ -1294,6 +1341,17 @@ Coming from computer vision and new to transformers? Here are some resources tha } ``` +```bibtex +@misc{renggli2022learning, + title = {Learning to Merge Tokens in Vision Transformers}, + author = {Cedric Renggli and André Susano Pinto and Neil Houlsby and Basil Mustafa and Joan Puigcerver and Carlos Riquelme}, + year = {2022}, + eprint = {2202.12015}, + archivePrefix = {arXiv}, + primaryClass = {cs.CV} +} +``` + ```bibtex @misc{vaswani2017attention, title = {Attention Is All You Need}, diff --git a/images/patch_merger.png b/images/patch_merger.png new file mode 100644 index 0000000..b7a537c Binary files /dev/null and b/images/patch_merger.png differ diff --git a/setup.py b/setup.py index 122e5b0..9f50a6d 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.26.7', + version = '0.27.0', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', diff --git a/vit_pytorch/vit_with_patch_merger.py b/vit_pytorch/vit_with_patch_merger.py new file mode 100644 index 0000000..3106bb3 --- /dev/null +++ b/vit_pytorch/vit_with_patch_merger.py @@ -0,0 +1,144 @@ +import torch +from torch import nn + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce + +# helpers + +def exists(val): + return val is not None + +def default(val ,d): + return val if exists(val) else d + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +# patch merger class + +class PatchMerger(nn.Module): + def __init__(self, dim, num_tokens_out): + super().__init__() + self.scale = dim ** -0.5 + self.norm = nn.LayerNorm(dim) + self.queries = nn.Parameter(torch.randn(num_tokens_out, dim)) + + def forward(self, x): + x = self.norm(x) + sim = torch.matmul(self.queries, x.transpose(-1, -2)) * self.scale + attn = sim.softmax(dim = -1) + return torch.matmul(attn, x) + +# classes + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8): + super().__init__() + self.layers = nn.ModuleList([]) + + self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper + self.patch_merger = PatchMerger(dim = dim, num_tokens_out = patch_merge_num_tokens) + + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def forward(self, x): + for index, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + + if index == self.patch_merge_layer_index: + x = self.patch_merger(x) + + return x + +class ViT(nn.Module): + def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): + super().__init__() + image_height, image_width = pair(image_size) + patch_height, patch_width = pair(patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' + + num_patches = (image_height // patch_height) * (image_width // patch_width) + patch_dim = channels * patch_height * patch_width + + self.to_patch_embedding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), + nn.Linear(patch_dim, dim), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, patch_merge_layer, patch_merge_num_tokens) + + self.mlp_head = nn.Sequential( + Reduce('b n d -> b d', 'mean'), + nn.LayerNorm(dim), + nn.Linear(dim, num_classes) + ) + + def forward(self, img): + x = self.to_patch_embedding(img) + b, n, _ = x.shape + + x += self.pos_embedding[:, :n] + x = self.dropout(x) + + x = self.transformer(x) + + return self.mlp_head(x)