diff --git a/setup.py b/setup.py index d39d544..e3b22c9 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.2.5', + version = '0.2.6', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', diff --git a/vit_pytorch/efficient.py b/vit_pytorch/efficient.py index ed1d775..2f5d798 100644 --- a/vit_pytorch/efficient.py +++ b/vit_pytorch/efficient.py @@ -30,10 +30,11 @@ class ViT(nn.Module): x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) x = self.patch_to_embedding(x) + b, n, _ = x.shape - cls_tokens = self.cls_token.expand(img.shape[0], -1, -1) + cls_tokens = self.cls_token.expand(b, -1, -1) x = torch.cat((cls_tokens, x), dim=1) - x += self.pos_embedding + x += self.pos_embedding[:, :(n + 1)] x = self.transformer(x) x = self.to_cls_token(x[:, 0]) diff --git a/vit_pytorch/vit_pytorch.py b/vit_pytorch/vit_pytorch.py index c97b638..910c8b1 100644 --- a/vit_pytorch/vit_pytorch.py +++ b/vit_pytorch/vit_pytorch.py @@ -113,10 +113,11 @@ class ViT(nn.Module): x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) x = self.patch_to_embedding(x) + b, n, _ = x.shape - cls_tokens = self.cls_token.expand(img.shape[0], -1, -1) + cls_tokens = self.cls_token.expand(b, -1, -1) x = torch.cat((cls_tokens, x), dim=1) - x += self.pos_embedding + x += self.pos_embedding[:, :(n + 1)] x = self.dropout(x) x = self.transformer(x, mask)