From f8bec5ede2f0cf9c2cf7166d72e43b96c3684baf Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 13 Aug 2025 10:15:18 -0700
Subject: [PATCH] able to project the image embedding before applying time
 positional embedding for accept video wrapper

---
 setup.py                            |  2 +-
 vit_pytorch/accept_video_wrapper.py | 32 +++++++++++++++++++++++++++++++-----
 2 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/setup.py b/setup.py
index 64f9572..7baa04a 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.11.5',
+  version = '1.11.6',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description = long_description,
diff --git a/vit_pytorch/accept_video_wrapper.py b/vit_pytorch/accept_video_wrapper.py
index 6370229..f913955 100644
--- a/vit_pytorch/accept_video_wrapper.py
+++ b/vit_pytorch/accept_video_wrapper.py
@@ -2,7 +2,7 @@ from contextlib import nullcontext
 
 import torch
 from torch import is_tensor, randn
-from torch.nn import Module, Parameter
+from torch.nn import Module, Linear, Parameter
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from einops import rearrange, repeat
@@ -26,7 +26,8 @@ class AcceptVideoWrapper(Module):
         dim_emb = None,
         time_seq_len = None,
         embed_is_channel_first = False,
-        output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
+        output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
+        proj_embed_to_dim = None
     ):
         super().__init__()
         self.image_net = image_net
@@ -35,11 +36,23 @@ class AcceptVideoWrapper(Module):
         self.add_time_pos_emb = add_time_pos_emb
         self.output_pos_add_pos_emb = output_pos_add_pos_emb
 
+        # maybe project the image embedding
+
+        self.embed_proj = None
+
+        if exists(proj_embed_to_dim):
+            assert exists(dim_emb), '`dim_emb` must be passed in'
+            self.embed_proj = Linear(dim_emb, proj_embed_to_dim)
+
+        # time positional embedding
+
         if add_time_pos_emb:
             assert exists(dim_emb) and exists(time_seq_len), '`dim_emb` and `time_seq_len` must be set if adding positional embeddings to the output'
             self.time_seq_len = time_seq_len
 
-            self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
+            dim_pos_emb = default(proj_embed_to_dim, dim_emb)
+
+            self.pos_emb = Parameter(randn(time_seq_len, dim_pos_emb) * 1e-2)
 
         self.embed_is_channel_first = embed_is_channel_first
 
@@ -79,6 +92,15 @@
 
         outputs = tuple(rearrange(t, '(b t) ... -> b t ...', t = time) if is_tensor(t) and t.numel() > 1 else t for t in outputs)
 
+        # maybe project embedding
+
+        if exists(self.embed_proj):
+            outputs = list(outputs)
+
+            embed = outputs[self.output_pos_add_pos_emb]
+
+            outputs[self.output_pos_add_pos_emb] = self.embed_proj(embed)
+
         # maybe add time positional embedding
 
         if add_time_pos_emb:
@@ -131,9 +153,9 @@ if __name__ == '__main__':
     from vit_pytorch.extractor import Extractor
     v = Extractor(v)
 
-    video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024)
+    video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024, proj_embed_to_dim = 512)
 
     logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
 
     assert logits.shape == (1, 7, 1000)
-    assert embeddings.shape == (1, 7, 65, 1024)
+    assert embeddings.shape == (1, 7, 65, 512)
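
Usage note (not part of the patch): the sketch below shows how the new `proj_embed_to_dim` argument would be used end to end, adapted from the `__main__` test in the diff above. The ViT hyperparameters (image_size, patch_size, depth, heads, mlp_dim) and the video tensor shape are illustrative assumptions chosen to be consistent with the asserted output shapes; the wrapper arguments and the asserts come from the patch itself.

# minimal usage sketch, assuming this patch is applied (vit-pytorch 1.11.6)

import torch
from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor
from vit_pytorch.accept_video_wrapper import AcceptVideoWrapper

# an image model with embedding dimension 1024
# image_size / patch_size chosen so there are 8 x 8 = 64 patch tokens + 1 CLS = 65 tokens

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# Extractor makes the model return (logits, embeddings) - embeddings sit at output position 1

v = Extractor(v)

video_acceptor = AcceptVideoWrapper(
    v,
    add_time_pos_emb = True,
    output_pos_add_pos_emb = 1,  # which output the projection and time pos emb are applied to
    time_seq_len = 12,
    dim_emb = 1024,              # embedding dim coming out of the image model
    proj_embed_to_dim = 512      # new in this patch: Linear(1024, 512) applied to the embedding
)

videos = torch.randn(1, 3, 7, 256, 256)  # (batch, channels, time, height, width)

logits, embeddings = video_acceptor(videos, eval_with_no_grad = True)

assert logits.shape == (1, 7, 1000)
assert embeddings.shape == (1, 7, 65, 512)  # projected from 1024 down to 512

Note that when `proj_embed_to_dim` is set, the time positional embedding is created in the projected dimension (`default(proj_embed_to_dim, dim_emb)`), so the projection runs first and the positional embedding is added to the projected output.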