Fix the video acceptor when the video's time sequence length is less than the maximum time sequence length

This commit is contained in:
lucidrains
2025-07-27 09:00:56 -07:00
parent e05cd6d8b8
commit 29ac8e143c
2 changed files with 6 additions and 6 deletions

View File

@@ -6,7 +6,7 @@ with open('README.md') as f:
setup( setup(
name = 'vit-pytorch', name = 'vit-pytorch',
packages = find_packages(exclude=['examples']), packages = find_packages(exclude=['examples']),
version = '1.11.3', version = '1.11.4',
license='MIT', license='MIT',
description = 'Vision Transformer (ViT) - Pytorch', description = 'Vision Transformer (ViT) - Pytorch',
long_description = long_description, long_description = long_description,

View File

@@ -91,7 +91,7 @@ class AcceptVideoWrapper(Module):
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *((1,) * dims_to_unsqueeze) , pos_emb.shape[-1]) pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *((1,) * dims_to_unsqueeze) , pos_emb.shape[-1])
embed = embed + pos_emb embed = embed + pos_emb[:, :embed.shape[1]]
outputs[self.output_pos_add_pos_emb] = embed outputs[self.output_pos_add_pos_emb] = embed
@@ -114,16 +114,16 @@ if __name__ == '__main__':
emb_dropout = 0.1 emb_dropout = 0.1
) )
videos = torch.randn(1, 3, 10, 256, 256) videos = torch.randn(1, 3, 7, 256, 256)
# step up the difficulty and return embeddings for robotics # step up the difficulty and return embeddings for robotics
from vit_pytorch.extractor import Extractor from vit_pytorch.extractor import Extractor
v = Extractor(v) v = Extractor(v)
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 10, dim_emb = 1024) video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024)
logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2 logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
assert logits.shape == (1, 10, 1000) assert logits.shape == (1, 7, 1000)
assert embeddings.shape == (1, 10, 65, 1024) assert embeddings.shape == (1, 7, 65, 1024)