some models only return embeddings with some kwarg on forward

lucidrains
2025-07-27 08:46:43 -07:00
parent b46233c3d6
commit e05cd6d8b8
2 changed files with 4 additions and 3 deletions


@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.11.2',
+  version = '1.11.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description = long_description,


@@ -43,7 +43,8 @@ class AcceptVideoWrapper(Module):
     def forward(
         self,
         video, # (b c t h w)
-        eval_with_no_grad = False
+        eval_with_no_grad = False,
+        forward_kwargs = dict()
     ):
         add_time_pos_emb = self.add_time_pos_emb
         time = video.shape[2]
@@ -67,7 +68,7 @@ class AcceptVideoWrapper(Module):
         context = torch.no_grad if eval_with_no_grad else nullcontext

         with context():
-            outputs = func(video)
+            outputs = func(video, **forward_kwargs)

        # handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
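A minimal usage sketch of the new forward_kwargs argument, which is passed straight through to the wrapped model so that models which only return embeddings when given some keyword can still be driven by the wrapper. The import path vit_pytorch.accept_video_wrapper, the bare AcceptVideoWrapper(vit) constructor call, and the return_embeddings kwarg name are assumptions for illustration, not taken from this diff; substitute whatever your wrapped model actually accepts.

# sketch only - module path, constructor args and the return_embeddings kwarg are assumptions
import torch
from vit_pytorch import ViT
from vit_pytorch.accept_video_wrapper import AcceptVideoWrapper

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024
)

wrapper = AcceptVideoWrapper(vit)

video = torch.randn(1, 3, 10, 256, 256)  # (batch, channels, time, height, width)

# forward_kwargs is forwarded untouched to the wrapped model on every call
outputs = wrapper(
    video,
    eval_with_no_grad = True,
    forward_kwargs = dict()  # e.g. dict(return_embeddings = True) for a model that needs it
)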