From f7123720c320e2eea815cbd7d41c4bd293fbc75a Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 7 Oct 2020 11:21:03 -0700
Subject: [PATCH] add masking

---
 README.md                  |  4 +++-
 setup.py                   |  2 +-
 vit_pytorch/vit_pytorch.py | 38 +++++++++++++++++++++++++-------------
 3 files changed, 29 insertions(+), 15 deletions(-)

diff --git a/README.md b/README.md
index b7e00b6..cef6ae7 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,9 @@ v = ViT(
 )
 
 img = torch.randn(1, 3, 256, 256)
-preds = v(img) # (1, 1000)
+mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to attend to
+
+preds = v(img, mask = mask) # (1, 1000)
 ```
 
 ## Suggestion
diff --git a/setup.py b/setup.py
index 22ed1df..d5ff1b0 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(),
-  version = '0.0.3',
+  version = '0.0.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/vit_pytorch.py b/vit_pytorch/vit_pytorch.py
index f2ce2eb..1906d01 100644
--- a/vit_pytorch/vit_pytorch.py
+++ b/vit_pytorch/vit_pytorch.py
@@ -1,4 +1,5 @@
 import torch
+import torch.nn.functional as F
 from einops import rearrange
 from torch import nn
 
@@ -6,16 +7,16 @@ class Residual(nn.Module):
     def __init__(self, fn):
         super().__init__()
         self.fn = fn
-    def forward(self, x):
-        return self.fn(x) + x
+    def forward(self, x, **kwargs):
+        return self.fn(x, **kwargs) + x
 
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
         self.fn = fn
-    def forward(self, x):
-        return self.fn(self.norm(x))
+    def forward(self, x, **kwargs):
+        return self.fn(self.norm(x), **kwargs)
 
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim):
@@ -36,12 +37,21 @@ class Attention(nn.Module):
         self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
         self.to_out = nn.Linear(dim, dim)
 
-    def forward(self, x):
+    def forward(self, x, mask = None):
         b, n, _, h = *x.shape, self.heads
         qkv = self.to_qkv(x)
         q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
 
         dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
+
+        if mask is not None:
+            mask = F.pad(mask.flatten(1), (1, 0), value = True)
+            print(mask.shape[-1], dots.shape[-1])
+            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
+            mask = mask[:, None, :] * mask[:, :, None]
+            dots.masked_fill_(~mask, float('-inf'))
+            del mask
+
         attn = dots.softmax(dim=-1)
 
         out = torch.einsum('bhij,bhjd->bhid', attn, v)
@@ -52,15 +62,17 @@ class Attention(nn.Module):
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, mlp_dim):
         super().__init__()
-        layers = []
+        self.layers = nn.ModuleList([])
         for _ in range(depth):
-            layers.extend([
+            self.layers.append(nn.ModuleList([
                 Residual(PreNorm(dim, Attention(dim, heads = heads))),
                 Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
-            ])
-        self.net = nn.Sequential(*layers)
-    def forward(self, x):
-        return self.net(x)
+            ]))
+    def forward(self, x, mask = None):
+        for attn, ff in self.layers:
+            x = attn(x, mask = mask)
+            x = ff(x)
+        return x
 
 class ViT(nn.Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
@@ -84,7 +96,7 @@ class ViT(nn.Module):
             nn.Linear(mlp_dim, num_classes)
         )
 
-    def forward(self, img):
+    def forward(self, img, mask = None):
         p = self.patch_size
 
         x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
@@ -93,7 +105,7 @@ class ViT(nn.Module):
         cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
         x = torch.cat((cls_tokens, x), dim=1)
         x += self.pos_embedding
-        x = self.transformer(x)
+        x = self.transformer(x, mask)
 
         x = self.to_cls_token(x[:, 0])
         return self.mlp_head(x)
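
Below is a minimal standalone sketch (not part of the commit) of what the new masking path in `Attention.forward` does: the boolean patch mask is flattened, padded with `True` so the prepended CLS token is always visible, expanded into a pairwise token mask, and used to fill masked attention scores with `-inf` before the softmax. The 8 x 8 grid, the head count, and the random `dots` tensor standing in for real attention scores are assumptions chosen to match the README example.

```python
import torch
import torch.nn.functional as F

b, heads, grid = 1, 8, 8                    # batch size, attention heads, patch grid side (assumed)
n = grid * grid + 1                         # 64 patches + 1 CLS token = 65 tokens

mask = torch.ones(b, grid, grid).bool()     # True = this patch may be attended to
mask[:, grid // 2:, :] = False              # e.g. hide the bottom half of the image

dots = torch.randn(b, heads, n, n)          # stand-in for the scaled q.k^T attention scores

mask = F.pad(mask.flatten(1), (1, 0), value = True)  # (b, 65): keep the CLS token visible
mask = mask[:, None, :] * mask[:, :, None]           # (b, 65, 65): True only where both tokens are visible
dots.masked_fill_(~mask, float('-inf'))              # broadcasts over the heads dim (batch size 1 here)
attn = dots.softmax(dim = -1)

print(attn[0, 0, 0, 33:].sum())  # tensor(0.): the CLS query puts no weight on masked patches
```

End to end, the caller only supplies `v(img, mask = mask)`: `ViT.forward` hands the mask to `Transformer`, whose `Residual(PreNorm(...))` wrappers now forward `**kwargs`, so it reaches every `Attention` layer unchanged.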