VAAT should have two ears

This commit is contained in:
lucidrains
2025-11-22 08:32:23 -08:00
parent 6aa0374313
commit aa49c2783a
2 changed files with 50 additions and 17 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "vit-pytorch" name = "vit-pytorch"
version = "1.15.6" version = "1.15.7"
description = "Vision Transformer (ViT) - Pytorch" description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" } readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" } license = { file = "LICENSE" }

View File

@@ -421,7 +421,8 @@ class VAAT(Module):
dim_head, dim_head,
dim_action, dim_action,
mlp_dim, mlp_dim,
num_views = None, num_image_views = None,
num_audio_views = None,
num_tasks = None, num_tasks = None,
dim_extra_token = None, dim_extra_token = None,
num_register_tokens = 4, num_register_tokens = 4,
@@ -462,6 +463,8 @@ class VAAT(Module):
ast_dim = ast.dim ast_dim = ast.dim
self.ast_accept_spec = ast.accept_spec
assert ast.depth == depth or exists(ast_layer_indices), f'if the VAAT depth is not equal to the AST depth, you must pass in the indices from the AST to be layered to the VAAT in order from bottom to top' assert ast.depth == depth or exists(ast_layer_indices), f'if the VAAT depth is not equal to the AST depth, you must pass in the indices from the AST to be layered to the VAAT in order from bottom to top'
ast_layer_indices = default(ast_layer_indices, tuple(range(depth))) ast_layer_indices = default(ast_layer_indices, tuple(range(depth)))
@@ -480,7 +483,9 @@ class VAAT(Module):
# maybe view embeddings # maybe view embeddings
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None self.image_view_emb = nn.Parameter(torch.randn(num_image_views, vit_dim) * 1e-2) if exists(num_image_views) and num_image_views > 1 else None
self.audio_view_emb = nn.Parameter(torch.randn(num_audio_views, ast_dim) * 1e-2) if exists(num_audio_views) and num_audio_views > 1 else None
# handle maybe task conditioning # handle maybe task conditioning
@@ -523,12 +528,12 @@ class VAAT(Module):
def forward( def forward(
self, self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
audio_or_spec, # (b t) | (b f t) - batch, audio len | batch, spec freq, time audio_or_spec, # (b v? t) | (b v? f t) - batch, audio len | batch, spec freq, time
*, *,
extra = None, # (b d) - batch, dim extra extra = None, # (b d) - batch, dim extra
tasks = None, # (b) tasks = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False, return_hiddens = False,
freeze_vit = False, freeze_vit = False,
freeze_ast = False freeze_ast = False
@@ -551,11 +556,26 @@ class VAAT(Module):
assert video_or_image.shape[3] == self.time_seq_len assert video_or_image.shape[3] == self.time_seq_len
# audio shapes - adding view dim of 1 if implicit
if audio_or_spec.ndim == 2 and not self.ast_accept_spec:
audio_or_spec = rearrange(audio_or_spec, 'b t -> b 1 t')
elif audio_or_spec.ndim == 3 and self.ast_accept_spec:
audio_or_spec = rearrange(audio_or_spec, 'b f t -> b 1 f t')
# to images # to images
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w') images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
images, packed_shape = pack([images], '* c h w') images, image_packed_shape = pack([images], '* c h w')
# to audio
if self.ast_accept_spec:
audio_or_spec, audio_packed_shape = pack([audio_or_spec], '* f t')
else:
audio_or_spec, audio_packed_shape = pack([audio_or_spec], '* t')
# get representation trajectory from vit # get representation trajectory from vit
@@ -570,9 +590,9 @@ class VAAT(Module):
hiddens = hiddens[self.vit_layer_indices] hiddens = hiddens[self.vit_layer_indices]
# pack temporarily for embedding # unpack temporarily for embedding
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers hiddens, = unpack(hiddens, image_packed_shape, 'l * n d') # l for layers
# maybe add time embeddings # maybe add time embeddings
@@ -582,11 +602,11 @@ class VAAT(Module):
# maybe view embeddings # maybe view embeddings
if exists(self.view_emb): if exists(self.image_view_emb):
assert self.view_emb.shape[0] == hiddens.shape[2] assert self.image_view_emb.shape[0] == hiddens.shape[2]
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d') image_view_emb = rearrange(self.image_view_emb, 'v d -> v 1 1 d')
hiddens = hiddens + view_emb hiddens = hiddens + image_view_emb
# get representation trajectory from ast # get representation trajectory from ast
@@ -601,6 +621,18 @@ class VAAT(Module):
audio_hiddens = audio_hiddens[self.ast_layer_indices] audio_hiddens = audio_hiddens[self.ast_layer_indices]
# unpack audio temporarily for embedding
audio_hiddens, = unpack(audio_hiddens, audio_packed_shape, 'l * n d') # l for layers
# maybe audio view embeddings
if exists(self.audio_view_emb):
assert self.audio_view_emb.shape[0] == audio_hiddens.shape[2]
audio_view_emb = rearrange(self.audio_view_emb, 'v d -> v 1 1 d')
audio_hiddens = audio_hiddens + audio_view_emb
# maybe tasks # maybe tasks
if exists(tasks): if exists(tasks):
@@ -612,7 +644,7 @@ class VAAT(Module):
image_context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d') image_context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
audio_context = audio_hiddens # eventually handle views (stereo and beyond) audio_context = rearrange(audio_hiddens, 'l b v n d -> l b (v n) d')
# get main action tokens and maybe append extra # get main action tokens and maybe append extra
@@ -714,7 +746,8 @@ if __name__ == '__main__':
dim_action = 20, dim_action = 20,
action_chunk_len = 7, action_chunk_len = 7,
time_seq_len = 4, time_seq_len = 4,
num_views = 2, num_image_views = 2,
num_audio_views = 2,
num_tasks = 4, num_tasks = 4,
add_self_attn = True, add_self_attn = True,
dim_extra_token = 33, # extra token with some variable dimension dim_extra_token = 33, # extra token with some variable dimension
@@ -727,7 +760,7 @@ if __name__ == '__main__':
) )
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames) images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
audio = torch.randn(2, 14_100 * 5) audio = torch.randn(2, 2, 14_100 * 5)
tasks = torch.randint(0, 4, (2,)) tasks = torch.randint(0, 4, (2,))
extra = torch.randn(2, 33) # extra internal state extra = torch.randn(2, 33) # extra internal state