diff --git a/README.md b/README.md
index e8bb4cc..312fbe5 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ v = ViT(
     num_classes = 1000,
     dim = 1024,
     depth = 6,
-    heads = 8,
+    heads = 16,
     mlp_dim = 2048,
     dropout = 0.1,
     emb_dropout = 0.1
diff --git a/setup.py b/setup.py
index a801599..a276970 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.3.0',
+  version = '0.4.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/vit_pytorch.py b/vit_pytorch/vit_pytorch.py
index e5d6e3f..08b0762 100644
--- a/vit_pytorch/vit_pytorch.py
+++ b/vit_pytorch/vit_pytorch.py
@@ -34,14 +34,15 @@ class FeedForward(nn.Module):
         return self.net(x)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         super().__init__()
+        inner_dim = dim_head * heads
         self.heads = heads
         self.scale = dim ** -0.5
 
-        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
-            nn.Linear(dim, dim),
+            nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
         )
 
@@ -68,12 +69,12 @@
         return out
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, mlp_dim, dropout):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
         super().__init__()
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
+                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                 Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
             ]))
     def forward(self, x, mask = None):
@@ -83,7 +84,7 @@
         return x
 
 class ViT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
         super().__init__()
         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
         num_patches = (image_size // patch_size) ** 2
@@ -97,7 +98,7 @@
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
         self.dropout = nn.Dropout(emb_dropout)
 
-        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
 
         self.to_cls_token = nn.Identity()
 
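Net effect of this change: Attention now projects queries, keys, and values to inner_dim = dim_head * heads instead of dim, then projects back from inner_dim to dim, so the per-head width is decoupled from the model width and dim no longer has to be divisible by heads. A minimal usage sketch against the updated constructor follows; the image_size and patch_size values are illustrative (chosen to satisfy the divisibility assert), not taken from this diff:

    import torch
    from vit_pytorch import ViT

    v = ViT(
        image_size = 256,   # illustrative; must be divisible by patch_size
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        dim_head = 64,      # new knob: inner_dim = dim_head * heads = 1024
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    img = torch.randn(1, 3, 256, 256)   # one 256x256 RGB image
    preds = v(img)                      # (1, 1000) class logits

Since dim_head defaults to 64, existing call sites keep working; with heads = 16 the attention inner width happens to equal dim here, but any combination of heads and dim_head is valid because the output is always projected back to dim.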