diff --git a/README.md b/README.md
index e097d56..64a28a6 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,9 @@ v = ViT(
     dim = 1024,
     depth = 6,
     heads = 8,
-    mlp_dim = 2048
+    mlp_dim = 2048,
+    attn_dropout = 0.1,
+    ff_dropout = 0.1
 )
 
 img = torch.randn(1, 3, 256, 256)
diff --git a/vit_pytorch/vit_pytorch.py b/vit_pytorch/vit_pytorch.py
index c471dc0..e1781b2 100644
--- a/vit_pytorch/vit_pytorch.py
+++ b/vit_pytorch/vit_pytorch.py
@@ -19,24 +19,26 @@ class PreNorm(nn.Module):
         return self.fn(self.norm(x), **kwargs)
 
 class FeedForward(nn.Module):
-    def __init__(self, dim, hidden_dim):
+    def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
+            nn.Dropout(dropout),
             nn.Linear(hidden_dim, dim)
         )
     def forward(self, x):
         return self.net(x)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8):
+    def __init__(self, dim, heads = 8, dropout = 0.):
         super().__init__()
         self.heads = heads
         self.scale = dim ** -0.5
 
         self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
         self.to_out = nn.Linear(dim, dim)
+        self.dropout = nn.Dropout(dropout)
     def forward(self, x, mask = None):
         b, n, _, h = *x.shape, self.heads
         qkv = self.to_qkv(x)
@@ -52,6 +54,7 @@ class Attention(nn.Module):
             del mask
 
         attn = dots.softmax(dim=-1)
+        attn = self.dropout(attn)
 
         out = torch.einsum('bhij,bhjd->bhid', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
@@ -59,13 +62,13 @@ class Attention(nn.Module):
         return out
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, mlp_dim):
+    def __init__(self, dim, depth, heads, mlp_dim, attn_dropout, ff_dropout):
         super().__init__()
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Residual(PreNorm(dim, Attention(dim, heads = heads))),
-                Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
+                Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = attn_dropout))),
+                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = ff_dropout)))
             ]))
     def forward(self, x, mask = None):
         for attn, ff in self.layers:
@@ -74,7 +77,7 @@ class Transformer(nn.Module):
         return x
 
 class ViT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, attn_dropout = 0., ff_dropout = 0.):
         super().__init__()
         assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
         num_patches = (image_size // patch_size) ** 2
@@ -85,7 +88,7 @@ class ViT(nn.Module):
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
         self.patch_to_embedding = nn.Linear(patch_dim, dim)
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.transformer = Transformer(dim, depth, heads, mlp_dim)
+        self.transformer = Transformer(dim, depth, heads, mlp_dim, attn_dropout, ff_dropout)
 
         self.to_cls_token = nn.Identity()
 
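
For reference, a minimal usage sketch of the new attn_dropout / ff_dropout arguments. It assumes the package exposes ViT at the top level (as in the README import) and reuses the README's constructor values for the arguments not shown in the hunk above (image_size, patch_size, num_classes). The point it illustrates is standard nn.Dropout behaviour: the added dropout layers are only active in train() mode and become no-ops after eval(), so inference is unaffected.

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    attn_dropout = 0.1,   # dropout applied to the attention weights after softmax
    ff_dropout = 0.1      # dropout inside each feed-forward block
)

img = torch.randn(1, 3, 256, 256)

v.train()                 # dropout layers active during training
preds = v(img)            # logits of shape (1, 1000)

v.eval()                  # nn.Dropout becomes a no-op
with torch.no_grad():
    preds = v(img)        # deterministic inference

Both rates default to 0., so existing code that omits the new keyword arguments keeps its previous behaviour.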