forgot task conditioning for vat

This commit is contained in:
lucidrains
2025-10-23 10:55:16 -07:00
parent e66862bcd5
commit 25871013f5
2 changed files with 49 additions and 4 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "vit-pytorch" name = "vit-pytorch"
version = "1.14.1" version = "1.14.2"
description = "Vision Transformer (ViT) - Pytorch" description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" } readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" } license = { file = "LICENSE" }

View File

@@ -21,6 +21,27 @@ def pair(t):
# classes # classes
class FiLM(Module):
def __init__(
self,
dim,
):
super().__init__()
proj = nn.Linear(dim, dim * 2)
self.to_gamma_beta = nn.Sequential(
proj,
Rearrange('b (two d) -> two b 1 d', two = 2)
)
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)
def forward(self, tokens, cond):
gamma, beta = self.to_gamma_beta(cond)
return tokens * gamma + beta
class FeedForward(Module): class FeedForward(Module):
def __init__( def __init__(
self, self,
@@ -228,6 +249,7 @@ class VAT(Module):
dim_action, dim_action,
mlp_dim, mlp_dim,
num_views = None, num_views = None,
num_tasks = None,
dim_extra_token = None, dim_extra_token = None,
action_chunk_len = 7, action_chunk_len = 7,
time_seq_len = 1, time_seq_len = 1,
@@ -266,6 +288,13 @@ class VAT(Module):
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
# handle maybe task conditioning
self.has_tasks = exists(num_tasks)
if self.has_tasks:
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
# to action tokens # to action tokens
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2) self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
@@ -273,9 +302,11 @@ class VAT(Module):
self.layers = ModuleList([]) self.layers = ModuleList([])
for _ in range(depth): for _ in range(depth):
maybe_film = FiLM(dim = dim) if self.has_tasks else None
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
self.layers.append(ModuleList([ self.layers.append(ModuleList([
maybe_film,
maybe_self_attn, maybe_self_attn,
Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True), Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout) FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
@@ -294,7 +325,9 @@ class VAT(Module):
def forward( def forward(
self, self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
*,
extra = None, # (b d) - batch, dim extra extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension actions = None, # (b k d) - batch, action chunk length, action dimension
): ):
batch = video_or_image.shape[0] batch = video_or_image.shape[0]
@@ -349,6 +382,13 @@ class VAT(Module):
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d') view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
hiddens = hiddens + view_emb hiddens = hiddens + view_emb
# maybe tasks
if exists(tasks):
assert self.has_tasks, f'`num_tasks` must be set on `VAT` for task conditioning'
task_emb = self.task_emb[tasks]
# cross from actions to representation trajectory # cross from actions to representation trajectory
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d') context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
@@ -368,7 +408,10 @@ class VAT(Module):
# cross attention # cross attention
for (maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context): for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
if exists(tasks):
action_tokens = maybe_film(action_tokens, task_emb)
action_tokens = cross_attn(action_tokens, layer_context) + action_tokens action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
@@ -422,6 +465,7 @@ if __name__ == '__main__':
action_chunk_len = 7, action_chunk_len = 7,
time_seq_len = 4, time_seq_len = 4,
num_views = 2, num_views = 2,
num_tasks = 4,
add_self_attn = True, add_self_attn = True,
dim_extra_token = 33, # extra token with some variable dimension dim_extra_token = 33, # extra token with some variable dimension
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit) vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
@@ -430,15 +474,16 @@ if __name__ == '__main__':
) )
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames) images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
tasks = torch.randint(0, 4, (2,))
extra = torch.randn(2, 33) # extra internal state extra = torch.randn(2, 33) # extra internal state
actions = torch.randn(2, 7, 20) # actions for learning actions = torch.randn(2, 7, 20) # actions for learning
loss = vat(images, actions = actions, extra = extra) loss = vat(images, actions = actions, tasks = tasks, extra = extra)
loss.backward() loss.backward()
# after much training # after much training
pred_actions = vat(images) pred_actions = vat(images, tasks = tasks, extra = extra)
assert pred_actions.shape == (2, 7, 20) assert pred_actions.shape == (2, 7, 20)