mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
forgot task conditioning for vat
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vit-pytorch"
|
||||
version = "1.14.1"
|
||||
version = "1.14.2"
|
||||
description = "Vision Transformer (ViT) - Pytorch"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
license = { file = "LICENSE" }
|
||||
|
||||
@@ -21,6 +21,27 @@ def pair(t):
|
||||
|
||||
# classes
|
||||
|
||||
class FiLM(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
):
|
||||
super().__init__()
|
||||
proj = nn.Linear(dim, dim * 2)
|
||||
|
||||
self.to_gamma_beta = nn.Sequential(
|
||||
proj,
|
||||
Rearrange('b (two d) -> two b 1 d', two = 2)
|
||||
)
|
||||
|
||||
nn.init.zeros_(proj.weight)
|
||||
nn.init.zeros_(proj.bias)
|
||||
|
||||
def forward(self, tokens, cond):
|
||||
gamma, beta = self.to_gamma_beta(cond)
|
||||
|
||||
return tokens * gamma + beta
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -228,6 +249,7 @@ class VAT(Module):
|
||||
dim_action,
|
||||
mlp_dim,
|
||||
num_views = None,
|
||||
num_tasks = None,
|
||||
dim_extra_token = None,
|
||||
action_chunk_len = 7,
|
||||
time_seq_len = 1,
|
||||
@@ -266,6 +288,13 @@ class VAT(Module):
|
||||
|
||||
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
|
||||
|
||||
# handle maybe task conditioning
|
||||
|
||||
self.has_tasks = exists(num_tasks)
|
||||
|
||||
if self.has_tasks:
|
||||
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
|
||||
|
||||
# to action tokens
|
||||
|
||||
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
|
||||
@@ -273,9 +302,11 @@ class VAT(Module):
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
maybe_film = FiLM(dim = dim) if self.has_tasks else None
|
||||
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
|
||||
|
||||
self.layers.append(ModuleList([
|
||||
maybe_film,
|
||||
maybe_self_attn,
|
||||
Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
|
||||
FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
|
||||
@@ -294,7 +325,9 @@ class VAT(Module):
|
||||
def forward(
|
||||
self,
|
||||
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
|
||||
*,
|
||||
extra = None, # (b d) - batch, dim extra
|
||||
tasks = None, # (b)
|
||||
actions = None, # (b k d) - batch, action chunk length, action dimension
|
||||
):
|
||||
batch = video_or_image.shape[0]
|
||||
@@ -349,6 +382,13 @@ class VAT(Module):
|
||||
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
|
||||
hiddens = hiddens + view_emb
|
||||
|
||||
# maybe tasks
|
||||
|
||||
if exists(tasks):
|
||||
assert self.has_tasks, f'`num_tasks` must be set on `VAT` for task conditioning'
|
||||
|
||||
task_emb = self.task_emb[tasks]
|
||||
|
||||
# cross from actions to representation trajectory
|
||||
|
||||
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
|
||||
@@ -368,7 +408,10 @@ class VAT(Module):
|
||||
|
||||
# cross attention
|
||||
|
||||
for (maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
|
||||
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
|
||||
|
||||
if exists(tasks):
|
||||
action_tokens = maybe_film(action_tokens, task_emb)
|
||||
|
||||
action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
|
||||
|
||||
@@ -422,6 +465,7 @@ if __name__ == '__main__':
|
||||
action_chunk_len = 7,
|
||||
time_seq_len = 4,
|
||||
num_views = 2,
|
||||
num_tasks = 4,
|
||||
add_self_attn = True,
|
||||
dim_extra_token = 33, # extra token with some variable dimension
|
||||
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
|
||||
@@ -430,15 +474,16 @@ if __name__ == '__main__':
|
||||
)
|
||||
|
||||
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
|
||||
tasks = torch.randint(0, 4, (2,))
|
||||
extra = torch.randn(2, 33) # extra internal state
|
||||
|
||||
actions = torch.randn(2, 7, 20) # actions for learning
|
||||
|
||||
loss = vat(images, actions = actions, extra = extra)
|
||||
loss = vat(images, actions = actions, tasks = tasks, extra = extra)
|
||||
loss.backward()
|
||||
|
||||
# after much training
|
||||
|
||||
pred_actions = vat(images)
|
||||
pred_actions = vat(images, tasks = tasks, extra = extra)
|
||||
|
||||
assert pred_actions.shape == (2, 7, 20)
|
||||
|
||||
Reference in New Issue
Block a user