mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
VAAT should have two ears
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "vit-pytorch"
|
name = "vit-pytorch"
|
||||||
version = "1.15.6"
|
version = "1.15.7"
|
||||||
description = "Vision Transformer (ViT) - Pytorch"
|
description = "Vision Transformer (ViT) - Pytorch"
|
||||||
readme = { file = "README.md", content-type = "text/markdown" }
|
readme = { file = "README.md", content-type = "text/markdown" }
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
|
|||||||
@@ -421,7 +421,8 @@ class VAAT(Module):
|
|||||||
dim_head,
|
dim_head,
|
||||||
dim_action,
|
dim_action,
|
||||||
mlp_dim,
|
mlp_dim,
|
||||||
num_views = None,
|
num_image_views = None,
|
||||||
|
num_audio_views = None,
|
||||||
num_tasks = None,
|
num_tasks = None,
|
||||||
dim_extra_token = None,
|
dim_extra_token = None,
|
||||||
num_register_tokens = 4,
|
num_register_tokens = 4,
|
||||||
@@ -462,6 +463,8 @@ class VAAT(Module):
|
|||||||
|
|
||||||
ast_dim = ast.dim
|
ast_dim = ast.dim
|
||||||
|
|
||||||
|
self.ast_accept_spec = ast.accept_spec
|
||||||
|
|
||||||
assert ast.depth == depth or exists(ast_layer_indices), f'if the VAAT depth is not equal to the AST depth, you must pass in the indices from the AST to be layered to the VAAT in order from bottom to top'
|
assert ast.depth == depth or exists(ast_layer_indices), f'if the VAAT depth is not equal to the AST depth, you must pass in the indices from the AST to be layered to the VAAT in order from bottom to top'
|
||||||
|
|
||||||
ast_layer_indices = default(ast_layer_indices, tuple(range(depth)))
|
ast_layer_indices = default(ast_layer_indices, tuple(range(depth)))
|
||||||
@@ -480,7 +483,9 @@ class VAAT(Module):
|
|||||||
|
|
||||||
# maybe view embeddings
|
# maybe view embeddings
|
||||||
|
|
||||||
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
|
self.image_view_emb = nn.Parameter(torch.randn(num_image_views, vit_dim) * 1e-2) if exists(num_image_views) and num_image_views > 1 else None
|
||||||
|
|
||||||
|
self.audio_view_emb = nn.Parameter(torch.randn(num_audio_views, ast_dim) * 1e-2) if exists(num_audio_views) and num_audio_views > 1 else None
|
||||||
|
|
||||||
# handle maybe task conditioning
|
# handle maybe task conditioning
|
||||||
|
|
||||||
@@ -523,12 +528,12 @@ class VAAT(Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
|
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
|
||||||
audio_or_spec, # (b t) | (b f t) - batch, audio len | batch, spec freq, time
|
audio_or_spec, # (b v? t) | (b v? f t) - batch, maybe views, audio len | batch, maybe views, spec freq, time
|
||||||
*,
|
*,
|
||||||
extra = None, # (b d) - batch, dim extra
|
extra = None, # (b d) - batch, dim extra
|
||||||
tasks = None, # (b)
|
tasks = None, # (b)
|
||||||
actions = None, # (b k d) - batch, action chunk length, action dimension
|
actions = None, # (b k d) - batch, action chunk length, action dimension
|
||||||
return_hiddens = False,
|
return_hiddens = False,
|
||||||
freeze_vit = False,
|
freeze_vit = False,
|
||||||
freeze_ast = False
|
freeze_ast = False
|
||||||
@@ -551,11 +556,26 @@ class VAAT(Module):
|
|||||||
|
|
||||||
assert video_or_image.shape[3] == self.time_seq_len
|
assert video_or_image.shape[3] == self.time_seq_len
|
||||||
|
|
||||||
|
# audio shapes - add a view dimension of 1 when the view axis is implicit
|
||||||
|
|
||||||
|
if audio_or_spec.ndim == 2 and not self.ast_accept_spec:
|
||||||
|
audio_or_spec = rearrange(audio_or_spec, 'b t -> b 1 t')
|
||||||
|
|
||||||
|
elif audio_or_spec.ndim == 3 and self.ast_accept_spec:
|
||||||
|
audio_or_spec = rearrange(audio_or_spec, 'b f t -> b 1 f t')
|
||||||
|
|
||||||
# to images
|
# to images
|
||||||
|
|
||||||
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
|
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
|
||||||
|
|
||||||
images, packed_shape = pack([images], '* c h w')
|
images, image_packed_shape = pack([images], '* c h w')
|
||||||
|
|
||||||
|
# to audio
|
||||||
|
|
||||||
|
if self.ast_accept_spec:
|
||||||
|
audio_or_spec, audio_packed_shape = pack([audio_or_spec], '* f t')
|
||||||
|
else:
|
||||||
|
audio_or_spec, audio_packed_shape = pack([audio_or_spec], '* t')
|
||||||
|
|
||||||
# get representation trajectory from vit
|
# get representation trajectory from vit
|
||||||
|
|
||||||
@@ -570,9 +590,9 @@ class VAAT(Module):
|
|||||||
|
|
||||||
hiddens = hiddens[self.vit_layer_indices]
|
hiddens = hiddens[self.vit_layer_indices]
|
||||||
|
|
||||||
# pack temporarily for embedding
|
# unpack temporarily for embedding
|
||||||
|
|
||||||
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers
|
hiddens, = unpack(hiddens, image_packed_shape, 'l * n d') # l for layers
|
||||||
|
|
||||||
# maybe add time embeddings
|
# maybe add time embeddings
|
||||||
|
|
||||||
@@ -582,11 +602,11 @@ class VAAT(Module):
|
|||||||
|
|
||||||
# maybe view embeddings
|
# maybe view embeddings
|
||||||
|
|
||||||
if exists(self.view_emb):
|
if exists(self.image_view_emb):
|
||||||
assert self.view_emb.shape[0] == hiddens.shape[2]
|
assert self.image_view_emb.shape[0] == hiddens.shape[2]
|
||||||
|
|
||||||
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
|
image_view_emb = rearrange(self.image_view_emb, 'v d -> v 1 1 d')
|
||||||
hiddens = hiddens + view_emb
|
hiddens = hiddens + image_view_emb
|
||||||
|
|
||||||
# get representation trajectory from ast
|
# get representation trajectory from ast
|
||||||
|
|
||||||
@@ -601,6 +621,18 @@ class VAAT(Module):
|
|||||||
|
|
||||||
audio_hiddens = audio_hiddens[self.ast_layer_indices]
|
audio_hiddens = audio_hiddens[self.ast_layer_indices]
|
||||||
|
|
||||||
|
# unpack audio temporarily for embedding
|
||||||
|
|
||||||
|
audio_hiddens, = unpack(audio_hiddens, audio_packed_shape, 'l * n d') # l for layers
|
||||||
|
|
||||||
|
# maybe audio view embeddings
|
||||||
|
|
||||||
|
if exists(self.audio_view_emb):
|
||||||
|
assert self.audio_view_emb.shape[0] == audio_hiddens.shape[2]
|
||||||
|
|
||||||
|
audio_view_emb = rearrange(self.audio_view_emb, 'v d -> v 1 1 d')
|
||||||
|
audio_hiddens = audio_hiddens + audio_view_emb
|
||||||
|
|
||||||
# maybe tasks
|
# maybe tasks
|
||||||
|
|
||||||
if exists(tasks):
|
if exists(tasks):
|
||||||
@@ -612,7 +644,7 @@ class VAAT(Module):
|
|||||||
|
|
||||||
image_context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
|
image_context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
|
||||||
|
|
||||||
audio_context = audio_hiddens # eventually handle views (stereo and beyond)
|
audio_context = rearrange(audio_hiddens, 'l b v n d -> l b (v n) d')
|
||||||
|
|
||||||
# get main action tokens and maybe append extra
|
# get main action tokens and maybe append extra
|
||||||
|
|
||||||
@@ -714,7 +746,8 @@ if __name__ == '__main__':
|
|||||||
dim_action = 20,
|
dim_action = 20,
|
||||||
action_chunk_len = 7,
|
action_chunk_len = 7,
|
||||||
time_seq_len = 4,
|
time_seq_len = 4,
|
||||||
num_views = 2,
|
num_image_views = 2,
|
||||||
|
num_audio_views = 2,
|
||||||
num_tasks = 4,
|
num_tasks = 4,
|
||||||
add_self_attn = True,
|
add_self_attn = True,
|
||||||
dim_extra_token = 33, # extra token with some variable dimension
|
dim_extra_token = 33, # extra token with some variable dimension
|
||||||
@@ -727,7 +760,7 @@ if __name__ == '__main__':
|
|||||||
)
|
)
|
||||||
|
|
||||||
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
|
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
|
||||||
audio = torch.randn(2, 14_100 * 5)
|
audio = torch.randn(2, 2, 14_100 * 5)
|
||||||
|
|
||||||
tasks = torch.randint(0, 4, (2,))
|
tasks = torch.randint(0, 4, (2,))
|
||||||
extra = torch.randn(2, 33) # extra internal state
|
extra = torch.randn(2, 33) # extra internal state
|
||||||
|
|||||||
Reference in New Issue
Block a user