From b900850144ff995c02fe304a9b85774deb794c2f Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 23 Mar 2021 11:57:13 -0700
Subject: [PATCH] add deep vit

---
 README.md                                  | 50 +++++++++++++++++++---
 setup.py                                   |  2 +-
 vit_pytorch/__init__.py                    |  2 +-
 vit_pytorch/{vit_pytorch.py => deepvit.py} | 42 ++++++++++--------
 vit_pytorch/t2t.py                         |  3 +-
 5 files changed, 71 insertions(+), 28 deletions(-)
 rename vit_pytorch/{vit_pytorch.py => deepvit.py} (84%)

diff --git a/README.md b/README.md
index 693b09f..01053e9 100644
--- a/README.md
+++ b/README.md
@@ -117,6 +117,33 @@ v = v.to_vit()
 type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
 ```
 
+## Deep ViT
+
+This paper notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the Talking Heads paper from NLP.
+
+You can use it as follows
+
+```python
+import torch
+from vit_pytorch.deepvit import DeepViT
+
+v = DeepViT(
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    dim = 1024,
+    depth = 6,
+    heads = 16,
+    mlp_dim = 2048,
+    dropout = 0.1,
+    emb_dropout = 0.1
+)
+
+img = torch.randn(1, 3, 256, 256)
+
+preds = v(img) # (1, 1000)
+```
+
 ## Token-to-Token ViT
 
@@ -345,12 +372,23 @@ Coming from computer vision and new to transformers? Here are some resources tha
 
 ```bibtex
 @misc{yuan2021tokenstotoken,
-    title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
-    author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
-    year = {2021},
-    eprint = {2101.11986},
-    archivePrefix = {arXiv},
-    primaryClass = {cs.CV}
+    title   = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
+    author  = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
+    year    = {2021},
+    eprint  = {2101.11986},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
+
+```bibtex
+@misc{zhou2021deepvit,
+    title   = {DeepViT: Towards Deeper Vision Transformer},
+    author  = {Daquan Zhou and Bingyi Kang and Xiaojie Jin and Linjie Yang and Xiaochen Lian and Qibin Hou and Jiashi Feng},
+    year    = {2021},
+    eprint  = {2103.11886},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
 }
 ```
 
diff --git a/setup.py b/setup.py
index cc706ce..078d456 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.8.0',
+  version = '0.9.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/__init__.py b/vit_pytorch/__init__.py
index 2a84eb0..1edbfeb 100644
--- a/vit_pytorch/__init__.py
+++ b/vit_pytorch/__init__.py
@@ -1 +1 @@
-from vit_pytorch.vit_pytorch import ViT
+from vit_pytorch.vit import ViT
diff --git a/vit_pytorch/vit_pytorch.py b/vit_pytorch/deepvit.py
similarity index 84%
rename from vit_pytorch/vit_pytorch.py
rename to vit_pytorch/deepvit.py
index 3153618..bf9d228 100644
--- a/vit_pytorch/vit_pytorch.py
+++ b/vit_pytorch/deepvit.py
@@ -37,35 +37,41 @@ class Attention(nn.Module):
     def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         super().__init__()
         inner_dim = dim_head * heads
-        project_out = not (heads == 1 and dim_head == dim)
-
         self.heads = heads
         self.scale = dim_head ** -0.5
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
+        self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
+
+        self.reattn_norm = nn.Sequential(
+            Rearrange('b h i j -> b i j h'),
+            nn.LayerNorm(heads),
+            Rearrange('b i j h -> b h i j')
+        )
+
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        ) if project_out else nn.Identity()
+        )
 
-    def forward(self, x, mask = None):
+    def forward(self, x):
         b, n, _, h = *x.shape, self.heads
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
+        # attention
+
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
-        mask_value = -torch.finfo(dots.dtype).max
-
-        if mask is not None:
-            mask = F.pad(mask.flatten(1), (1, 0), value = True)
-            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
-            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')
-            dots.masked_fill_(~mask, mask_value)
-            del mask
-
         attn = dots.softmax(dim=-1)
 
+        # re-attention
+
+        attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
+        attn = self.reattn_norm(attn)
+
+        # aggregate and out
+
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
         out = self.to_out(out)
@@ -80,13 +86,13 @@ class Transformer(nn.Module):
                 Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                 Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
             ]))
-    def forward(self, x, mask = None):
+    def forward(self, x):
         for attn, ff in self.layers:
-            x = attn(x, mask = mask)
+            x = attn(x)
             x = ff(x)
         return x
 
-class ViT(nn.Module):
+class DeepViT(nn.Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
         super().__init__()
         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
@@ -113,7 +119,7 @@ class ViT(nn.Module):
             nn.Linear(dim, num_classes)
         )
 
-    def forward(self, img, mask = None):
+    def forward(self, img):
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape
 
@@ -122,7 +128,7 @@
         x += self.pos_embedding[:, :(n + 1)]
         x = self.dropout(x)
 
-        x = self.transformer(x, mask)
+        x = self.transformer(x)
 
         x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
 
diff --git a/vit_pytorch/t2t.py b/vit_pytorch/t2t.py
index 626aca0..e45fd8d 100644
--- a/vit_pytorch/t2t.py
+++ b/vit_pytorch/t2t.py
@@ -24,8 +24,7 @@ class RearrangeImage(nn.Module):
 # main class
 
 class T2TViT(nn.Module):
-    def __init__(
-        self, *, image_size, num_classes, dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))):
+    def __init__(self, *, image_size, num_classes, dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))):
         super().__init__()
         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
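
A minimal standalone sketch of the Re-attention step that the `deepvit.py` changes above implement may help make the head-mixing concrete. The tensor shapes, the dummy attention map, and the variable names below are illustrative assumptions; only the post-softmax einsum over heads and the LayerNorm applied across the head dimension mirror the patch.

```python
import torch
from torch import nn, einsum
from einops.layers.torch import Rearrange

heads = 16

# learned head-mixing matrix and a norm over the head dimension,
# mirroring reattn_weights / reattn_norm in the patch above
reattn_weights = nn.Parameter(torch.randn(heads, heads))
reattn_norm = nn.Sequential(
    Rearrange('b h i j -> b i j h'),
    nn.LayerNorm(heads),
    Rearrange('b i j h -> b h i j')
)

# dummy post-softmax attention maps: (batch, heads, queries, keys) - shape assumed for illustration
attn = torch.randn(1, heads, 65, 65).softmax(dim = -1)

# re-attention: mix each head's attention map with every other head, then renormalize over the head dimension
reattn = einsum('b h i j, h g -> b g i j', attn, reattn_weights)
reattn = reattn_norm(reattn)

print(reattn.shape) # torch.Size([1, 16, 65, 65])
```

The mixed maps are then used in place of the original `attn` when aggregating the values, as in the `forward` of the patched `Attention` module.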