From 25871013f59ccfeaa86369961ca63324c1e75173 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 23 Oct 2025 10:55:16 -0700
Subject: [PATCH] forgot task conditioning for vat

---
 pyproject.toml     |  2 +-
 vit_pytorch/vat.py | 51 +++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 49 insertions(+), 4 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 84419b7..f0d190e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.14.1"
+version = "1.14.2"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
diff --git a/vit_pytorch/vat.py b/vit_pytorch/vat.py
index 2284f71..be61c2f 100644
--- a/vit_pytorch/vat.py
+++ b/vit_pytorch/vat.py
@@ -21,6 +21,27 @@ def pair(t):
 
 # classes
 
+class FiLM(Module):
+    def __init__(
+        self,
+        dim,
+    ):
+        super().__init__()
+        proj = nn.Linear(dim, dim * 2)
+
+        self.to_gamma_beta = nn.Sequential(
+            proj,
+            Rearrange('b (two d) -> two b 1 d', two = 2)
+        )
+
+        nn.init.zeros_(proj.weight)
+        nn.init.zeros_(proj.bias)
+
+    def forward(self, tokens, cond):
+        gamma, beta = self.to_gamma_beta(cond)
+
+        return tokens * (gamma + 1.) + beta
+
 class FeedForward(Module):
     def __init__(
         self,
@@ -228,6 +249,7 @@ class VAT(Module):
         dim_action,
         mlp_dim,
         num_views = None,
+        num_tasks = None,
         dim_extra_token = None,
         action_chunk_len = 7,
         time_seq_len = 1,
@@ -266,6 +288,13 @@ class VAT(Module):
 
         self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
 
+        # handle maybe task conditioning
+
+        self.has_tasks = exists(num_tasks)
+
+        if self.has_tasks:
+            self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
+
         # to action tokens
 
         self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
@@ -273,9 +302,11 @@ class VAT(Module):
         self.layers = ModuleList([])
 
         for _ in range(depth):
+            maybe_film = FiLM(dim = dim) if self.has_tasks else None
             maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
 
             self.layers.append(ModuleList([
+                maybe_film,
                 maybe_self_attn,
                 Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
                 FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
@@ -294,7 +325,9 @@ class VAT(Module):
     def forward(
         self,
         video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
+        *,
         extra = None, # (b d) - batch, dim extra
+        tasks = None, # (b)
         actions = None, # (b k d) - batch, action chunk length, action dimension
     ):
         batch = video_or_image.shape[0]
@@ -349,6 +382,13 @@ class VAT(Module):
             view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
             hiddens = hiddens + view_emb
 
+        # maybe tasks
+
+        if exists(tasks):
+            assert self.has_tasks, '`num_tasks` must be set on `VAT` for task conditioning'
+
+            task_emb = self.task_emb[tasks]
+
         # cross from actions to representation trajectory
 
         context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
@@ -368,7 +408,10 @@ class VAT(Module):
 
         # cross attention
 
-        for (maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
+        for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
+
+            if exists(tasks):
+                action_tokens = maybe_film(action_tokens, task_emb)
 
             action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
 
@@ -422,6 +465,7 @@ if __name__ == '__main__':
         action_chunk_len = 7,
         time_seq_len = 4,
         num_views = 2,
+        num_tasks = 4,
         add_self_attn = True,
         dim_extra_token = 33, # extra token with some variable dimension
         vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
@@ -432,16 +476,17 @@ if __name__ == '__main__':
         )
     )
 
     images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
+    tasks = torch.randint(0, 4, (2,))
     extra = torch.randn(2, 33) # extra internal state
 
     actions = torch.randn(2, 7, 20) # actions for learning
 
-    loss = vat(images, actions = actions, extra = extra)
+    loss = vat(images, actions = actions, tasks = tasks, extra = extra)
     loss.backward()
 
     # after much training
 
-    pred_actions = vat(images)
+    pred_actions = vat(images, tasks = tasks, extra = extra)
 
     assert pred_actions.shape == (2, 7, 20)
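
A quick way to sanity check the new task conditioning is the identity-at-init property of FiLM: because the gamma/beta projection is zero initialized, the module should pass tokens through unchanged before any training. Below is a minimal standalone sketch, not part of the patch above; it assumes FiLM is imported from vit_pytorch.vat and the dimensions are arbitrary.

    # sketch: FiLM should act as the identity at initialization
    import torch
    from vit_pytorch.vat import FiLM

    film = FiLM(dim = 512)

    tokens = torch.randn(2, 7, 512) # (batch, action chunk length, dim)
    task_emb = torch.randn(2, 512)  # one conditioning vector per batch element

    out = film(tokens, task_emb)

    # the projection is zero initialized, so gamma = beta = 0
    # and tokens * (gamma + 1.) + beta == tokens
    assert torch.allclose(out, tokens)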