allow projecting the image embedding before applying the time positional embedding in AcceptVideoWrapper

lucidrains
2025-08-13 10:15:18 -07:00
parent 297e7d00a2
commit f8bec5ede2
2 changed files with 28 additions and 6 deletions


@@ -2,7 +2,7 @@ from contextlib import nullcontext
 import torch
 from torch import is_tensor, randn
-from torch.nn import Module, Parameter
+from torch.nn import Module, Linear, Parameter
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from einops import rearrange, repeat
@@ -26,7 +26,8 @@ class AcceptVideoWrapper(Module):
         dim_emb = None,
         time_seq_len = None,
         embed_is_channel_first = False,
-        output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
+        output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
+        proj_embed_to_dim = None
     ):
         super().__init__()
         self.image_net = image_net
@@ -35,11 +36,23 @@ class AcceptVideoWrapper(Module):
         self.add_time_pos_emb = add_time_pos_emb
         self.output_pos_add_pos_emb = output_pos_add_pos_emb

+        # maybe project the image embedding
+
+        self.embed_proj = None
+
+        if exists(proj_embed_to_dim):
+            assert exists(dim_emb), '`dim_emb` must be passed in'
+            self.embed_proj = Linear(dim_emb, proj_embed_to_dim)
+
         # time positional embedding

         if add_time_pos_emb:
             assert exists(dim_emb) and exists(time_seq_len), '`dim_emb` and `time_seq_len` must be set if adding positional embeddings to the output'
             self.time_seq_len = time_seq_len
-            self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
+            dim_pos_emb = default(proj_embed_to_dim, dim_emb)
+            self.pos_emb = Parameter(randn(time_seq_len, dim_pos_emb) * 1e-2)

         self.embed_is_channel_first = embed_is_channel_first
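In effect, when proj_embed_to_dim is set, the time positional embedding table is allocated at the projected width rather than at dim_emb, since the projection runs first. A minimal standalone sketch of that wiring (the concrete sizes are illustrative, borrowed from the test at the bottom of the file):

import torch
from torch.nn import Linear, Parameter

dim_emb, proj_embed_to_dim, time_seq_len = 1024, 512, 12

embed_proj = Linear(dim_emb, proj_embed_to_dim)

# the projection is applied before the positional embedding is added,
# so the table must live in the projected dimension (what default(...) selects above)
dim_pos_emb = proj_embed_to_dim if proj_embed_to_dim is not None else dim_emb
pos_emb = Parameter(torch.randn(time_seq_len, dim_pos_emb) * 1e-2)

assert pos_emb.shape == (12, 512)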
@@ -79,6 +92,15 @@ class AcceptVideoWrapper(Module):
         outputs = tuple(rearrange(t, '(b t) ... -> b t ...', t = time) if is_tensor(t) and t.numel() > 1 else t for t in outputs)

+        # maybe project embedding
+
+        if exists(self.embed_proj):
+            outputs = list(outputs)
+            embed = outputs[self.output_pos_add_pos_emb]
+            outputs[self.output_pos_add_pos_emb] = self.embed_proj(embed)
+
         # maybe add time positional embedding

         if add_time_pos_emb:
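The forward pass projects the selected output before the positional embedding is added, which is why the two widths must agree. A self-contained sketch of that ordering (the shapes and broadcasting pattern are assumptions for illustration, not the wrapper's exact code):

import torch
from torch.nn import Linear

embed_proj = Linear(1024, 512)
pos_emb = torch.randn(12, 512) * 1e-2

time, tokens = 7, 65
embed = torch.randn(1, time, tokens, 1024)  # (batch, time, tokens, dim_emb)

embed = embed_proj(embed)  # project the last dimension, 1024 -> 512

# one positional embedding per frame, broadcast over batch and tokens
embed = embed + pos_emb[:time].view(1, time, 1, 512)

assert embed.shape == (1, 7, 65, 512)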
@@ -131,9 +153,9 @@ if __name__ == '__main__':
     from vit_pytorch.extractor import Extractor
     v = Extractor(v)

-    video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024)
+    video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024, proj_embed_to_dim = 512)

     logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2

     assert logits.shape == (1, 7, 1000)
-    assert embeddings.shape == (1, 7, 65, 1024)
+    assert embeddings.shape == (1, 7, 65, 512)
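Putting it together, a hedged usage sketch mirroring the test above (the ViT configuration is an assumption, chosen so the Extractor yields 65 tokens of width 1024; any image model whose selected output matches dim_emb should work):

import torch
from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

# assumes AcceptVideoWrapper is importable from the module changed in this commit

v = ViT(
    image_size = 256,
    patch_size = 32,   # (256 / 32) ** 2 = 64 patches + 1 cls token = 65
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

v = Extractor(v)  # forward now returns (logits, embeddings)

video_acceptor = AcceptVideoWrapper(
    v,
    add_time_pos_emb = True,
    output_pos_add_pos_emb = 1,  # project and add pos emb to the embeddings output
    time_seq_len = 12,
    dim_emb = 1024,
    proj_embed_to_dim = 512      # the new parameter from this commit
)

videos = torch.randn(1, 3, 7, 256, 256)  # (batch, channels, time, height, width)

logits, embeddings = video_acceptor(videos, eval_with_no_grad = True)

assert logits.shape == (1, 7, 1000)
assert embeddings.shape == (1, 7, 65, 512)  # embedding width is now the projected 512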