diff --git a/README.md b/README.md
index 64a28a6..574559e 100644
--- a/README.md
+++ b/README.md
@@ -24,8 +24,8 @@ v = ViT(
     depth = 6,
     heads = 8,
     mlp_dim = 2048,
-    attn_dropout = 0.1,
-    ff_dropout = 0.1
+    dropout = 0.1,
+    emb_dropout = 0.1
 )
 
 img = torch.randn(1, 3, 256, 256)
diff --git a/setup.py b/setup.py
index 723730c..b3d03bf 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.2.1',
+  version = '0.2.2',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/vit_pytorch.py b/vit_pytorch/vit_pytorch.py
index e1781b2..816047e 100644
--- a/vit_pytorch/vit_pytorch.py
+++ b/vit_pytorch/vit_pytorch.py
@@ -25,7 +25,8 @@ class FeedForward(nn.Module):
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
-            nn.Linear(hidden_dim, dim)
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout)
         )
     def forward(self, x):
         return self.net(x)
@@ -37,8 +38,11 @@ class Attention(nn.Module):
         self.scale = dim ** -0.5
 
         self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
-        self.to_out = nn.Linear(dim, dim)
-        self.dropout = nn.Dropout(dropout)
+        self.to_out = nn.Sequential(
+            nn.Linear(dim, dim),
+            nn.Dropout(dropout)
+        )
+
     def forward(self, x, mask = None):
         b, n, _, h = *x.shape, self.heads
         qkv = self.to_qkv(x)
@@ -54,7 +58,6 @@ class Attention(nn.Module):
             del mask
 
         attn = dots.softmax(dim=-1)
-        attn = self.dropout(attn)
 
         out = torch.einsum('bhij,bhjd->bhid', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
@@ -62,13 +65,13 @@ class Attention(nn.Module):
         return out
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, mlp_dim, attn_dropout, ff_dropout):
+    def __init__(self, dim, depth, heads, mlp_dim, dropout):
         super().__init__()
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = attn_dropout))),
-                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = ff_dropout)))
+                Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
+                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
             ]))
     def forward(self, x, mask = None):
         for attn, ff in self.layers:
@@ -77,7 +80,7 @@ class Transformer(nn.Module):
         return x
 
 class ViT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, attn_dropout = 0., ff_dropout = 0.):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
         super().__init__()
         assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
         num_patches = (image_size // patch_size) ** 2
@@ -88,7 +91,9 @@ class ViT(nn.Module):
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
         self.patch_to_embedding = nn.Linear(patch_dim, dim)
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.transformer = Transformer(dim, depth, heads, mlp_dim, attn_dropout, ff_dropout)
+        self.dropout = nn.Dropout(emb_dropout)
+
+        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
 
         self.to_cls_token = nn.Identity()
 
@@ -96,7 +101,9 @@
             nn.LayerNorm(dim),
             nn.Linear(dim, mlp_dim),
             nn.GELU(),
-            nn.Linear(mlp_dim, num_classes)
+            nn.Dropout(dropout),
+            nn.Linear(mlp_dim, num_classes),
+            nn.Dropout(dropout)
         )
 
     def forward(self, img, mask = None):
@@ -108,6 +115,8 @@ class ViT(nn.Module):
         cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
         x = torch.cat((cls_tokens, x), dim=1)
         x += self.pos_embedding
+        x = self.dropout(x)
+
         x = self.transformer(x, mask)
 
         x = self.to_cls_token(x[:, 0])
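
Usage note: a minimal sketch of the consolidated dropout API introduced by this patch (dropout / emb_dropout). Only the arguments visible in the README hunk above are taken from the diff; image_size, patch_size, num_classes and dim are assumed values filled in for a runnable example.

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,    # assumed, not part of the hunk shown above
    patch_size = 32,     # assumed
    num_classes = 1000,  # assumed
    dim = 1024,          # assumed
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,       # now used in Attention output, FeedForward and the MLP head
    emb_dropout = 0.1    # applied to patch + positional embeddings before the transformer
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)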