mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8bec5ede2 | ||
|
|
297e7d00a2 | ||
|
|
29ac8e143c | ||
|
|
e05cd6d8b8 | ||
|
|
b46233c3d6 | ||
|
|
68e13a3c7d |
2
setup.py
2
setup.py
@@ -6,7 +6,7 @@ with open('README.md') as f:
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.11.0',
|
||||
version = '1.11.6',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description = long_description,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from torch import is_tensor, randn
|
||||
from torch.nn import Module, Parameter
|
||||
from torch.nn import Module, Linear, Parameter
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
||||
|
||||
from einops import rearrange, repeat
|
||||
@@ -23,7 +25,9 @@ class AcceptVideoWrapper(Module):
|
||||
add_time_pos_emb = False,
|
||||
dim_emb = None,
|
||||
time_seq_len = None,
|
||||
output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
|
||||
embed_is_channel_first = False,
|
||||
output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
|
||||
proj_embed_to_dim = None
|
||||
):
|
||||
super().__init__()
|
||||
self.image_net = image_net
|
||||
@@ -32,18 +36,34 @@ class AcceptVideoWrapper(Module):
|
||||
self.add_time_pos_emb = add_time_pos_emb
|
||||
self.output_pos_add_pos_emb = output_pos_add_pos_emb
|
||||
|
||||
# maybe project the image embedding
|
||||
|
||||
self.embed_proj = None
|
||||
|
||||
if exists(proj_embed_to_dim):
|
||||
assert exists(dim_emb), '`dim_emb` must be passed in'
|
||||
self.embed_proj = Linear(dim_emb, proj_embed_to_dim)
|
||||
|
||||
# time positional embedding
|
||||
|
||||
if add_time_pos_emb:
|
||||
assert exists(dim_emb) and exists(time_seq_len), '`dim_emb` and `time_seq_len` must be set if adding positional embeddings to the output'
|
||||
self.time_seq_len = time_seq_len
|
||||
|
||||
self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
|
||||
dim_pos_emb = default(proj_embed_to_dim, dim_emb)
|
||||
|
||||
self.pos_emb = Parameter(randn(time_seq_len, dim_pos_emb) * 1e-2)
|
||||
|
||||
self.embed_is_channel_first = embed_is_channel_first
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video # (b c t h w)
|
||||
video, # (b c t h w)
|
||||
eval_with_no_grad = False,
|
||||
forward_kwargs = dict()
|
||||
):
|
||||
add_time_pos_emb = self.add_time_pos_emb
|
||||
batch, time = video.shape[0], video.shape[2]
|
||||
time = video.shape[2]
|
||||
|
||||
# maybe validate time positional embedding
|
||||
|
||||
@@ -54,9 +74,17 @@ class AcceptVideoWrapper(Module):
|
||||
|
||||
video = rearrange(video, 'b t ... -> (b t) ...')
|
||||
|
||||
# forward through image net for outputs
|
||||
|
||||
func = getattr(self.image_net, self.forward_function)
|
||||
|
||||
outputs = func(video)
|
||||
if eval_with_no_grad:
|
||||
self.image_net.eval()
|
||||
|
||||
context = torch.no_grad if eval_with_no_grad else nullcontext
|
||||
|
||||
with context():
|
||||
outputs = func(video, **forward_kwargs)
|
||||
|
||||
# handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
|
||||
|
||||
@@ -64,14 +92,37 @@ class AcceptVideoWrapper(Module):
|
||||
|
||||
outputs = tuple(rearrange(t, '(b t) ... -> b t ...', t = time) if is_tensor(t) and t.numel() > 1 else t for t in outputs)
|
||||
|
||||
# maybe project embedding
|
||||
|
||||
if exists(self.embed_proj):
|
||||
outputs = list(outputs)
|
||||
|
||||
embed = outputs[self.output_pos_add_pos_emb]
|
||||
|
||||
outputs[self.output_pos_add_pos_emb] = self.embed_proj(embed)
|
||||
|
||||
# maybe add time positional embedding
|
||||
|
||||
if add_time_pos_emb:
|
||||
pos_emb = repeat(self.pos_emb, 't d -> b t 1 d', b = batch)
|
||||
|
||||
outputs = list(outputs)
|
||||
embed = outputs[self.output_pos_add_pos_emb]
|
||||
|
||||
pos_emb = rearrange(self.pos_emb, 't d -> 1 t d')
|
||||
|
||||
# handle the network outputting embeddings with spatial dimensions intact - assume embedded dimension is last
|
||||
|
||||
dims_to_unsqueeze = embed.ndim - pos_emb.ndim
|
||||
|
||||
one_dims = ((1,) * dims_to_unsqueeze)
|
||||
|
||||
if self.embed_is_channel_first:
|
||||
pos_emb = pos_emb.reshape(*pos_emb.shape, *one_dims)
|
||||
else:
|
||||
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *one_dims, pos_emb.shape[-1])
|
||||
|
||||
pos_emb = pos_emb[:, :embed.shape[1]]
|
||||
|
||||
embed = embed + pos_emb
|
||||
|
||||
outputs[self.output_pos_add_pos_emb] = embed
|
||||
@@ -95,16 +146,16 @@ if __name__ == '__main__':
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
videos = torch.randn(1, 3, 10, 256, 256)
|
||||
videos = torch.randn(1, 3, 7, 256, 256)
|
||||
|
||||
# step up the difficulty and return embeddings for robotics
|
||||
|
||||
from vit_pytorch.extractor import Extractor
|
||||
v = Extractor(v)
|
||||
|
||||
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 10, dim_emb = 1024)
|
||||
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024, proj_embed_to_dim = 512)
|
||||
|
||||
logits, embeddings = video_acceptor(videos) # always (batch, channels, time, height, width) - time is always dimension 2
|
||||
logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
|
||||
|
||||
assert logits.shape == (1, 10, 1000)
|
||||
assert embeddings.shape == (1, 10, 65, 1024)
|
||||
assert logits.shape == (1, 7, 1000)
|
||||
assert embeddings.shape == (1, 7, 65, 512)
|
||||
|
||||
Reference in New Issue
Block a user