handle channel first for accept video wrapper

This commit is contained in:
lucidrains
2025-08-03 08:29:40 -07:00
parent 29ac8e143c
commit 297e7d00a2
2 changed files with 13 additions and 3 deletions

View File

@@ -25,6 +25,7 @@ class AcceptVideoWrapper(Module):
add_time_pos_emb = False,
dim_emb = None,
time_seq_len = None,
embed_is_channel_first = False,
output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
):
super().__init__()
@@ -40,6 +41,8 @@ class AcceptVideoWrapper(Module):
self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
self.embed_is_channel_first = embed_is_channel_first
def forward(
self,
video, # (b c t h w)
@@ -89,9 +92,16 @@ class AcceptVideoWrapper(Module):
dims_to_unsqueeze = embed.ndim - pos_emb.ndim
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *((1,) * dims_to_unsqueeze) , pos_emb.shape[-1])
one_dims = ((1,) * dims_to_unsqueeze)
embed = embed + pos_emb[:, :embed.shape[1]]
if self.embed_is_channel_first:
pos_emb = pos_emb.reshape(*pos_emb.shape, *one_dims)
else:
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *one_dims, pos_emb.shape[-1])
pos_emb = pos_emb[:, :embed.shape[1]]
embed = embed + pos_emb
outputs[self.output_pos_add_pos_emb] = embed