Compare commits

...

3 Commits

Author SHA1 Message Date
lucidrains
e05cd6d8b8 some models only return embeddings with some kwarg on forward 2025-07-27 08:46:43 -07:00
lucidrains
b46233c3d6 need to be able to invoke with eval no grad 2025-07-27 08:25:58 -07:00
lucidrains
68e13a3c7d bit more flexible 2025-07-27 08:14:48 -07:00
2 changed files with 25 additions and 6 deletions

View File

@@ -6,7 +6,7 @@ with open('README.md') as f:
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.11.0',
version = '1.11.3',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description = long_description,

View File

@@ -1,3 +1,5 @@
from contextlib import nullcontext
import torch
from torch import is_tensor, randn
from torch.nn import Module, Parameter
@@ -40,10 +42,12 @@ class AcceptVideoWrapper(Module):
def forward(
self,
video # (b c t h w)
video, # (b c t h w)
eval_with_no_grad = False,
forward_kwargs = dict()
):
add_time_pos_emb = self.add_time_pos_emb
batch, time = video.shape[0], video.shape[2]
time = video.shape[2]
# maybe validate time positional embedding
@@ -54,9 +58,17 @@ class AcceptVideoWrapper(Module):
video = rearrange(video, 'b t ... -> (b t) ...')
# forward through image net for outputs
func = getattr(self.image_net, self.forward_function)
outputs = func(video)
if eval_with_no_grad:
self.image_net.eval()
context = torch.no_grad if eval_with_no_grad else nullcontext
with context():
outputs = func(video, **forward_kwargs)
# handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
@@ -67,11 +79,18 @@ class AcceptVideoWrapper(Module):
# maybe add time positional embedding
if add_time_pos_emb:
pos_emb = repeat(self.pos_emb, 't d -> b t 1 d', b = batch)
outputs = list(outputs)
embed = outputs[self.output_pos_add_pos_emb]
pos_emb = rearrange(self.pos_emb, 't d -> 1 t d')
# handle the network outputting embeddings with spatial dimensions intact - assume embedded dimension is last
dims_to_unsqueeze = embed.ndim - pos_emb.ndim
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *((1,) * dims_to_unsqueeze) , pos_emb.shape[-1])
embed = embed + pos_emb
outputs[self.output_pos_add_pos_emb] = embed
@@ -104,7 +123,7 @@ if __name__ == '__main__':
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 10, dim_emb = 1024)
logits, embeddings = video_acceptor(videos) # always (batch, channels, time, height, width) - time is always dimension 2
logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
assert logits.shape == (1, 10, 1000)
assert embeddings.shape == (1, 10, 65, 1024)