diff --git a/setup.py b/setup.py
index 79d7059..dfd7e52 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.19.0',
+  version = '0.19.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/nest.py b/vit_pytorch/nest.py
index a54fa34..b4ee105 100644
--- a/vit_pytorch/nest.py
+++ b/vit_pytorch/nest.py
@@ -48,16 +48,16 @@ class FeedForward(nn.Module):
 class Attention(nn.Module):
     def __init__(self, dim, heads = 8, dropout = 0.):
         super().__init__()
-        assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
         dim_head = dim // heads
+        inner_dim = dim_head * heads
         self.heads = heads
         self.scale = dim_head ** -0.5

         self.attend = nn.Softmax(dim = -1)

-        self.to_qkv = nn.Conv2d(dim, dim * 3, 1, bias = False)
+        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)

         self.to_out = nn.Sequential(
-            nn.Conv2d(dim, dim, 1),
+            nn.Conv2d(inner_dim, dim, 1),
             nn.Dropout(dropout)
         )

@@ -129,7 +129,8 @@ class NesT(nn.Module):
         blocks = 2 ** (num_heirarchies - 1)

         seq_len = (fmap_size // blocks) ** 2   # sequence length is held constant across heirarchy
-        mults = [2 ** i for i in reversed(range(num_heirarchies))]
+        heirarchies = list(reversed(range(num_heirarchies)))
+        mults = [2 ** i for i in heirarchies]

         layer_heads = list(map(lambda t: t * heads, mults))
         layer_dims = list(map(lambda t: t * dim, mults))
@@ -146,7 +147,7 @@ class NesT(nn.Module):
         self.layers = nn.ModuleList([])

-        for level, heads, (dim_in, dim_out), block_repeat in zip(reversed(range(num_heirarchies)), layer_heads, dim_pairs, block_repeats):
+        for level, heads, (dim_in, dim_out), block_repeat in zip(heirarchies, layer_heads, dim_pairs, block_repeats):
             is_last = level == 0
             depth = block_repeat
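
Below is a minimal smoke test (not part of the patch) sketching what the `nest.py` change enables; it assumes `Attention` is importable directly from `vit_pytorch.nest`, per the file path in the diff. With the divisibility assert removed, a `dim` that is not a multiple of `heads` now works: the head dimension rounds down, and the new `inner_dim` projections keep the input and output channel counts consistent.

```python
import torch
from vit_pytorch.nest import Attention  # import path assumed from the diff

# 100 is not divisible by 8; before this patch the constructor asserted.
# Now dim_head = 100 // 8 = 12 and inner_dim = 12 * 8 = 96: to_qkv projects
# dim -> inner_dim * 3, and to_out projects inner_dim back to dim.
attn = Attention(dim = 100, heads = 8, dropout = 0.)

fmap = torch.randn(2, 100, 14, 14)  # NesT attention runs on (batch, dim, h, w) feature maps
out = attn(fmap)

assert out.shape == fmap.shape  # channels restored to dim = 100 by to_out
```

The `heirarchies` refactor in the second and third `nest.py` hunks is behavior-preserving: it computes `list(reversed(range(num_heirarchies)))` once and reuses it in both the `mults` comprehension and the layer-construction `zip`, rather than rebuilding the reversed range in two places.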