diff --git a/setup.py b/setup.py
index 6a8ec1e..2a3abf8 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.20.3',
+  version = '0.20.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/t2t.py b/vit_pytorch/t2t.py
index 0901c48..c5582df 100644
--- a/vit_pytorch/t2t.py
+++ b/vit_pytorch/t2t.py
@@ -35,13 +35,14 @@ class T2TViT(nn.Module):
         for i, (kernel_size, stride) in enumerate(t2t_layers):
             layer_dim *= kernel_size ** 2
             is_first = i == 0
+            is_last = i == (len(t2t_layers) - 1)
             output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)

             layers.extend([
                 RearrangeImage() if not is_first else nn.Identity(),
                 nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
                 Rearrange('b c n -> b n c'),
-                Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout),
+                Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout) if not is_last else nn.Identity(),
             ])

         layers.append(nn.Linear(layer_dim, dim))
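
For reference, a minimal usage sketch of the changed path, mirroring the T2T example in the repo README; the hyperparameters are illustrative and not part of this patch. With the is_last guard, the final t2t stage (the last (3, 2) pair below) runs only the unfold and rearrange steps and skips the per-stage single-head Transformer before the linear projection to dim.

import torch
from vit_pytorch.t2t import T2TViT

# hypothetical hyperparameters, taken from the repo README's T2T example
v = T2TViT(
    dim = 512,
    image_size = 224,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    num_classes = 1000,
    t2t_layers = ((7, 4), (3, 2), (3, 2))  # (kernel_size, stride) of each token-to-token stage
)

img = torch.randn(1, 3, 224, 224)
preds = v(img)  # (1, 1000)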