From 3f754956fbfb1f97ae4f1e244a7ecb16eab79296 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 14 Aug 2021 08:06:23 -0700 Subject: [PATCH] remove last transformer layer in t2t --- setup.py | 2 +- vit_pytorch/t2t.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 6a8ec1e..2a3abf8 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.20.3', + version = '0.20.4', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', diff --git a/vit_pytorch/t2t.py b/vit_pytorch/t2t.py index 0901c48..c5582df 100644 --- a/vit_pytorch/t2t.py +++ b/vit_pytorch/t2t.py @@ -35,13 +35,14 @@ class T2TViT(nn.Module): for i, (kernel_size, stride) in enumerate(t2t_layers): layer_dim *= kernel_size ** 2 is_first = i == 0 + is_last = i == (len(t2t_layers) - 1) output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2) layers.extend([ RearrangeImage() if not is_first else nn.Identity(), nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2), Rearrange('b c n -> b n c'), - Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout), + Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout) if not is_last else nn.Identity(), ]) layers.append(nn.Linear(layer_dim, dim))