Compare commits

...

6 Commits

2 changed files with 64 additions and 13 deletions

View File

@@ -6,7 +6,7 @@ with open('README.md') as f:
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.11.0',
version = '1.11.6',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description = long_description,

View File

@@ -1,6 +1,8 @@
from contextlib import nullcontext
import torch
from torch import is_tensor, randn
from torch.nn import Module, Parameter
from torch.nn import Module, Linear, Parameter
from torch.utils._pytree import tree_flatten, tree_unflatten
from einops import rearrange, repeat
@@ -23,7 +25,9 @@ class AcceptVideoWrapper(Module):
add_time_pos_emb = False,
dim_emb = None,
time_seq_len = None,
output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
embed_is_channel_first = False,
output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
proj_embed_to_dim = None
):
super().__init__()
self.image_net = image_net
@@ -32,18 +36,34 @@ class AcceptVideoWrapper(Module):
self.add_time_pos_emb = add_time_pos_emb
self.output_pos_add_pos_emb = output_pos_add_pos_emb
# maybe project the image embedding
self.embed_proj = None
if exists(proj_embed_to_dim):
assert exists(dim_emb), '`dim_emb` must be passed in'
self.embed_proj = Linear(dim_emb, proj_embed_to_dim)
# time positional embedding
if add_time_pos_emb:
assert exists(dim_emb) and exists(time_seq_len), '`dim_emb` and `time_seq_len` must be set if adding positional embeddings to the output'
self.time_seq_len = time_seq_len
self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
dim_pos_emb = default(proj_embed_to_dim, dim_emb)
self.pos_emb = Parameter(randn(time_seq_len, dim_pos_emb) * 1e-2)
self.embed_is_channel_first = embed_is_channel_first
def forward(
self,
video # (b c t h w)
video, # (b c t h w)
eval_with_no_grad = False,
forward_kwargs = dict()
):
add_time_pos_emb = self.add_time_pos_emb
batch, time = video.shape[0], video.shape[2]
time = video.shape[2]
# maybe validate time positional embedding
@@ -54,9 +74,17 @@ class AcceptVideoWrapper(Module):
video = rearrange(video, 'b t ... -> (b t) ...')
# forward through image net for outputs
func = getattr(self.image_net, self.forward_function)
outputs = func(video)
if eval_with_no_grad:
self.image_net.eval()
context = torch.no_grad if eval_with_no_grad else nullcontext
with context():
outputs = func(video, **forward_kwargs)
# handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
@@ -64,14 +92,37 @@ class AcceptVideoWrapper(Module):
outputs = tuple(rearrange(t, '(b t) ... -> b t ...', t = time) if is_tensor(t) and t.numel() > 1 else t for t in outputs)
# maybe project embedding
if exists(self.embed_proj):
outputs = list(outputs)
embed = outputs[self.output_pos_add_pos_emb]
outputs[self.output_pos_add_pos_emb] = self.embed_proj(embed)
# maybe add time positional embedding
if add_time_pos_emb:
pos_emb = repeat(self.pos_emb, 't d -> b t 1 d', b = batch)
outputs = list(outputs)
embed = outputs[self.output_pos_add_pos_emb]
pos_emb = rearrange(self.pos_emb, 't d -> 1 t d')
# handle the network outputting embeddings with spatial dimensions intact - assume embedded dimension is last
dims_to_unsqueeze = embed.ndim - pos_emb.ndim
one_dims = ((1,) * dims_to_unsqueeze)
if self.embed_is_channel_first:
pos_emb = pos_emb.reshape(*pos_emb.shape, *one_dims)
else:
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *one_dims, pos_emb.shape[-1])
pos_emb = pos_emb[:, :embed.shape[1]]
embed = embed + pos_emb
outputs[self.output_pos_add_pos_emb] = embed
@@ -95,16 +146,16 @@ if __name__ == '__main__':
emb_dropout = 0.1
)
videos = torch.randn(1, 3, 10, 256, 256)
videos = torch.randn(1, 3, 7, 256, 256)
# step up the difficulty and return embeddings for robotics
from vit_pytorch.extractor import Extractor
v = Extractor(v)
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 10, dim_emb = 1024)
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024, proj_embed_to_dim = 512)
logits, embeddings = video_acceptor(videos) # always (batch, channels, time, height, width) - time is always dimension 2
logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
assert logits.shape == (1, 10, 1000)
assert embeddings.shape == (1, 10, 65, 1024)
assert logits.shape == (1, 7, 1000)
assert embeddings.shape == (1, 7, 65, 512)