Handle channel-first embeddings when adding the time positional embedding in AcceptVideoWrapper

This commit is contained in:
lucidrains
2025-08-03 08:29:40 -07:00
parent 29ac8e143c
commit 297e7d00a2
2 changed files with 13 additions and 3 deletions

View File

@@ -6,7 +6,7 @@ with open('README.md') as f:
setup( setup(
name = 'vit-pytorch', name = 'vit-pytorch',
packages = find_packages(exclude=['examples']), packages = find_packages(exclude=['examples']),
version = '1.11.4', version = '1.11.5',
license='MIT', license='MIT',
description = 'Vision Transformer (ViT) - Pytorch', description = 'Vision Transformer (ViT) - Pytorch',
long_description = long_description, long_description = long_description,

View File

@@ -25,6 +25,7 @@ class AcceptVideoWrapper(Module):
add_time_pos_emb = False, add_time_pos_emb = False,
dim_emb = None, dim_emb = None,
time_seq_len = None, time_seq_len = None,
embed_is_channel_first = False,
output_pos_add_pos_emb = 0 # defaults to first output position to add embedding output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
): ):
super().__init__() super().__init__()
@@ -40,6 +41,8 @@ class AcceptVideoWrapper(Module):
self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2) self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
self.embed_is_channel_first = embed_is_channel_first
def forward( def forward(
self, self,
video, # (b c t h w) video, # (b c t h w)
@@ -89,9 +92,16 @@ class AcceptVideoWrapper(Module):
dims_to_unsqueeze = embed.ndim - pos_emb.ndim dims_to_unsqueeze = embed.ndim - pos_emb.ndim
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *((1,) * dims_to_unsqueeze) , pos_emb.shape[-1]) one_dims = ((1,) * dims_to_unsqueeze)
embed = embed + pos_emb[:, :embed.shape[1]] if self.embed_is_channel_first:
pos_emb = pos_emb.reshape(*pos_emb.shape, *one_dims)
else:
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *one_dims, pos_emb.shape[-1])
pos_emb = pos_emb[:, :embed.shape[1]]
embed = embed + pos_emb
outputs[self.output_pos_add_pos_emb] = embed outputs[self.output_pos_add_pos_emb] = embed