diff --git a/setup.py b/setup.py index 308c818..0b04304 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.2.1', + version = '1.2.2', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description_content_type = 'text/markdown', diff --git a/vit_pytorch/vivit.py b/vit_pytorch/vivit.py index 082b699..50daa65 100644 --- a/vit_pytorch/vivit.py +++ b/vit_pytorch/vivit.py @@ -146,7 +146,7 @@ class ViT(nn.Module): x = self.to_patch_embedding(video) b, f, n, _ = x.shape - x = x + self.pos_embedding + x = x + self.pos_embedding[:, :f, :n] if exists(self.spatial_cls_token): spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)