From 297e7d00a20628ad075470362c215095dbf5c7bd Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 3 Aug 2025 08:29:40 -0700 Subject: [PATCH] handle channel first for accept video wrapper --- setup.py | 2 +- vit_pytorch/accept_video_wrapper.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index a3881ab..64f9572 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ with open('README.md') as f: setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.11.4', + version = '1.11.5', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description = long_description, diff --git a/vit_pytorch/accept_video_wrapper.py b/vit_pytorch/accept_video_wrapper.py index ee10f0e..6370229 100644 --- a/vit_pytorch/accept_video_wrapper.py +++ b/vit_pytorch/accept_video_wrapper.py @@ -25,6 +25,7 @@ class AcceptVideoWrapper(Module): add_time_pos_emb = False, dim_emb = None, time_seq_len = None, + embed_is_channel_first = False, output_pos_add_pos_emb = 0 # defaults to first output position to add embedding ): super().__init__() @@ -40,6 +41,8 @@ class AcceptVideoWrapper(Module): self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2) + self.embed_is_channel_first = embed_is_channel_first + def forward( self, video, # (b c t h w) @@ -89,9 +92,16 @@ class AcceptVideoWrapper(Module): dims_to_unsqueeze = embed.ndim - pos_emb.ndim - pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *((1,) * dims_to_unsqueeze) , pos_emb.shape[-1]) + one_dims = ((1,) * dims_to_unsqueeze) - embed = embed + pos_emb[:, :embed.shape[1]] + if self.embed_is_channel_first: + pos_emb = pos_emb.reshape(*pos_emb.shape, *one_dims) + else: + pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *one_dims, pos_emb.shape[-1]) + + pos_emb = pos_emb[:, :embed.shape[1]] + + embed = embed + pos_emb outputs[self.output_pos_add_pos_emb] = embed