mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
need to be able to invoke with eval no grad
This commit is contained in:
2
setup.py
2
setup.py
@@ -6,7 +6,7 @@ with open('README.md') as f:
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.11.1',
|
||||
version = '1.11.2',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description = long_description,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from torch import is_tensor, randn
|
||||
from torch.nn import Module, Parameter
|
||||
@@ -40,7 +42,8 @@ class AcceptVideoWrapper(Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video # (b c t h w)
|
||||
video, # (b c t h w)
|
||||
eval_with_no_grad = False
|
||||
):
|
||||
add_time_pos_emb = self.add_time_pos_emb
|
||||
time = video.shape[2]
|
||||
@@ -54,9 +57,17 @@ class AcceptVideoWrapper(Module):
|
||||
|
||||
video = rearrange(video, 'b t ... -> (b t) ...')
|
||||
|
||||
# forward through image net for outputs
|
||||
|
||||
func = getattr(self.image_net, self.forward_function)
|
||||
|
||||
outputs = func(video)
|
||||
if eval_with_no_grad:
|
||||
self.image_net.eval()
|
||||
|
||||
context = torch.no_grad if eval_with_no_grad else nullcontext
|
||||
|
||||
with context():
|
||||
outputs = func(video)
|
||||
|
||||
# handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
|
||||
|
||||
@@ -111,7 +122,7 @@ if __name__ == '__main__':
|
||||
|
||||
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 10, dim_emb = 1024)
|
||||
|
||||
logits, embeddings = video_acceptor(videos) # always (batch, channels, time, height, width) - time is always dimension 2
|
||||
logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
|
||||
|
||||
assert logits.shape == (1, 10, 1000)
|
||||
assert embeddings.shape == (1, 10, 65, 1024)
|
||||
|
||||
Reference in New Issue
Block a user