From e05cd6d8b8dd4eb0e97f9af43ba595cca1bcb229 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 27 Jul 2025 08:46:43 -0700 Subject: [PATCH] some models only return embeddings with some kwarg on forward --- setup.py | 2 +- vit_pytorch/accept_video_wrapper.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 9f75b95..d438cbe 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ with open('README.md') as f: setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.11.2', + version = '1.11.3', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description = long_description, diff --git a/vit_pytorch/accept_video_wrapper.py b/vit_pytorch/accept_video_wrapper.py index 108ffe9..c5ef059 100644 --- a/vit_pytorch/accept_video_wrapper.py +++ b/vit_pytorch/accept_video_wrapper.py @@ -43,7 +43,8 @@ class AcceptVideoWrapper(Module): def forward( self, video, # (b c t h w) - eval_with_no_grad = False + eval_with_no_grad = False, + forward_kwargs = dict() ): add_time_pos_emb = self.add_time_pos_emb time = video.shape[2] @@ -67,7 +68,7 @@ class AcceptVideoWrapper(Module): context = torch.no_grad if eval_with_no_grad else nullcontext with context(): - outputs = func(video) + outputs = func(video, **forward_kwargs) # handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned