diff --git a/setup.py b/setup.py index c74f398..9f75b95 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ with open('README.md') as f: setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.11.1', + version = '1.11.2', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description = long_description, diff --git a/vit_pytorch/accept_video_wrapper.py b/vit_pytorch/accept_video_wrapper.py index 297a670..108ffe9 100644 --- a/vit_pytorch/accept_video_wrapper.py +++ b/vit_pytorch/accept_video_wrapper.py @@ -1,3 +1,5 @@ +from contextlib import nullcontext + import torch from torch import is_tensor, randn from torch.nn import Module, Parameter @@ -40,7 +42,8 @@ class AcceptVideoWrapper(Module): def forward( self, - video # (b c t h w) + video, # (b c t h w) + eval_with_no_grad = False ): add_time_pos_emb = self.add_time_pos_emb time = video.shape[2] @@ -54,9 +57,17 @@ class AcceptVideoWrapper(Module): video = rearrange(video, 'b t ... -> (b t) ...') + # forward through image net for outputs + func = getattr(self.image_net, self.forward_function) - outputs = func(video) + if eval_with_no_grad: + self.image_net.eval() + + context = torch.no_grad if eval_with_no_grad else nullcontext + + with context(): + outputs = func(video) # handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned @@ -111,7 +122,7 @@ if __name__ == '__main__': video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 10, dim_emb = 1024) - logits, embeddings = video_acceptor(videos) # always (batch, channels, time, height, width) - time is always dimension 2 + logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2 assert logits.shape == (1, 10, 1000) assert embeddings.shape == (1, 10, 65, 1024)