mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
some models only return embeddings with some kwarg on forward
This commit is contained in:
2
setup.py
2
setup.py
@@ -6,7 +6,7 @@ with open('README.md') as f:
|
|||||||
setup(
|
setup(
|
||||||
name = 'vit-pytorch',
|
name = 'vit-pytorch',
|
||||||
packages = find_packages(exclude=['examples']),
|
packages = find_packages(exclude=['examples']),
|
||||||
version = '1.11.2',
|
version = '1.11.3',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'Vision Transformer (ViT) - Pytorch',
|
description = 'Vision Transformer (ViT) - Pytorch',
|
||||||
long_description = long_description,
|
long_description = long_description,
|
||||||
|
|||||||
@@ -43,7 +43,8 @@ class AcceptVideoWrapper(Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
video, # (b c t h w)
|
video, # (b c t h w)
|
||||||
eval_with_no_grad = False
|
eval_with_no_grad = False,
|
||||||
|
forward_kwargs = dict()
|
||||||
):
|
):
|
||||||
add_time_pos_emb = self.add_time_pos_emb
|
add_time_pos_emb = self.add_time_pos_emb
|
||||||
time = video.shape[2]
|
time = video.shape[2]
|
||||||
@@ -67,7 +68,7 @@ class AcceptVideoWrapper(Module):
|
|||||||
context = torch.no_grad if eval_with_no_grad else nullcontext
|
context = torch.no_grad if eval_with_no_grad else nullcontext
|
||||||
|
|
||||||
with context():
|
with context():
|
||||||
outputs = func(video)
|
outputs = func(video, **forward_kwargs)
|
||||||
|
|
||||||
# handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
|
# handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user