bit more flexible

This commit is contained in:
lucidrains
2025-07-27 08:14:48 -07:00
parent b22dc0ecd2
commit 68e13a3c7d
2 changed files with 10 additions and 3 deletions

View File

@@ -6,7 +6,7 @@ with open('README.md') as f:
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.11.0',
version = '1.11.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description = long_description,

View File

@@ -43,7 +43,7 @@ class AcceptVideoWrapper(Module):
video # (b c t h w)
):
add_time_pos_emb = self.add_time_pos_emb
batch, time = video.shape[0], video.shape[2]
time = video.shape[2]
# maybe validate time positional embedding
@@ -67,11 +67,18 @@ class AcceptVideoWrapper(Module):
# maybe add time positional embedding
if add_time_pos_emb:
pos_emb = repeat(self.pos_emb, 't d -> b t 1 d', b = batch)
outputs = list(outputs)
embed = outputs[self.output_pos_add_pos_emb]
pos_emb = rearrange(self.pos_emb, 't d -> 1 t d')
# handle the network outputting embeddings with spatial dimensions intact - assume the embedding dimension is last
dims_to_unsqueeze = embed.ndim - pos_emb.ndim
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *((1,) * dims_to_unsqueeze) , pos_emb.shape[-1])
embed = embed + pos_emb
outputs[self.output_pos_add_pos_emb] = embed