diff --git a/setup.py b/setup.py
index d438cbe..a3881ab 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.11.3',
+  version = '1.11.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description = long_description,
diff --git a/vit_pytorch/accept_video_wrapper.py b/vit_pytorch/accept_video_wrapper.py
index c5ef059..ee10f0e 100644
--- a/vit_pytorch/accept_video_wrapper.py
+++ b/vit_pytorch/accept_video_wrapper.py
@@ -91,7 +91,7 @@ class AcceptVideoWrapper(Module):
 
             pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *((1,) * dims_to_unsqueeze) , pos_emb.shape[-1])
 
-            embed = embed + pos_emb
+            embed = embed + pos_emb[:, :embed.shape[1]]
 
             outputs[self.output_pos_add_pos_emb] = embed
@@ -114,16 +114,16 @@ if __name__ == '__main__':
         emb_dropout = 0.1
     )
 
-    videos = torch.randn(1, 3, 10, 256, 256)
+    videos = torch.randn(1, 3, 7, 256, 256)
 
     # step up the difficulty and return embeddings for robotics
 
     from vit_pytorch.extractor import Extractor
     v = Extractor(v)
 
-    video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 10, dim_emb = 1024)
+    video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024)
 
     logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
 
-    assert logits.shape == (1, 10, 1000)
-    assert embeddings.shape == (1, 10, 65, 1024)
+    assert logits.shape == (1, 7, 1000)
+    assert embeddings.shape == (1, 7, 65, 1024)
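
Why the pos_emb[:, :embed.shape[1]] slice matters, as a minimal standalone sketch (not part of the diff; the shapes below are taken from the updated test: a 7-frame video, 65 tokens per frame, dim 1024, time_seq_len 12):

    import torch

    time_seq_len, frames, tokens, dim = 12, 7, 65, 1024

    # learned time positional embedding, already reshaped with a
    # broadcast dimension over the per-frame tokens
    pos_emb = torch.randn(1, time_seq_len, 1, dim)

    # per-frame embeddings for a video shorter than time_seq_len
    embed = torch.randn(1, frames, tokens, dim)

    # before this patch: embed + pos_emb raises a shape error
    # (12 vs 7 at the time dimension); after it, slicing the
    # embedding to the actual frame count broadcasts cleanly,
    # so time_seq_len only needs to upper-bound the video length
    out = embed + pos_emb[:, :embed.shape[1]]
    assert out.shape == (1, frames, tokens, dim)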