mirror of https://github.com/lucidrains/vit-pytorch.git
forgot task conditioning for vat

pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.14.1"
+version = "1.14.2"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }

vit_pytorch/vat.py
@@ -21,6 +21,27 @@ def pair(t):
 
 # classes
 
+class FiLM(Module):
+    def __init__(
+        self,
+        dim,
+    ):
+        super().__init__()
+        proj = nn.Linear(dim, dim * 2)
+
+        self.to_gamma_beta = nn.Sequential(
+            proj,
+            Rearrange('b (two d) -> two b 1 d', two = 2)
+        )
+
+        nn.init.zeros_(proj.weight)
+        nn.init.zeros_(proj.bias)
+
+    def forward(self, tokens, cond):
+        gamma, beta = self.to_gamma_beta(cond)
+
+        return tokens * (gamma + 1.) + beta
+
 class FeedForward(Module):
     def __init__(
         self,
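
The FiLM layer above is plain feature-wise linear modulation: one linear layer maps the conditioning vector to a per-channel scale and shift that are broadcast over the token sequence, and the `+ 1.` on gamma is what lets the zero-initialized projection start as the identity. A minimal self-contained sketch of the same idea, using only standard torch + einops:

import torch
from torch import nn
from einops.layers.torch import Rearrange

class FiLM(nn.Module):
    def __init__(self, dim):
        super().__init__()
        proj = nn.Linear(dim, dim * 2)

        # one projection produces both gamma and beta, split out into a leading axis of 2
        self.to_gamma_beta = nn.Sequential(proj, Rearrange('b (two d) -> two b 1 d', two = 2))

        # zero init: gamma = beta = 0 at the start, so conditioning is a no-op until trained
        nn.init.zeros_(proj.weight)
        nn.init.zeros_(proj.bias)

    def forward(self, tokens, cond):
        gamma, beta = self.to_gamma_beta(cond)   # each (b, 1, d), broadcast over the sequence
        return tokens * (gamma + 1.) + beta

film = FiLM(dim = 32)
tokens = torch.randn(2, 7, 32)   # (batch, action chunk, dim)
cond = torch.randn(2, 32)        # one conditioning vector per sample

assert torch.allclose(film(tokens, cond), tokens)   # exact identity before any training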
@@ -228,6 +249,7 @@ class VAT(Module):
         dim_action,
         mlp_dim,
         num_views = None,
+        num_tasks = None,
         dim_extra_token = None,
         action_chunk_len = 7,
         time_seq_len = 1,
@@ -266,6 +288,13 @@ class VAT(Module):
 
         self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
 
+        # handle maybe task conditioning
+
+        self.has_tasks = exists(num_tasks)
+
+        if self.has_tasks:
+            self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
+
         # to action tokens
 
         self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
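
`task_emb` is just a learned table, one row per task, initialized small like the view and positional embeddings around it. Indexing it with a batch of integer ids yields one conditioning vector per sample, and gradients flow back only into the rows that were selected. A quick sketch with hypothetical sizes:

import torch
from torch import nn

num_tasks, dim = 4, 32
task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)

tasks = torch.randint(0, num_tasks, (2,))   # (b,) integer task ids
cond = task_emb[tasks]                      # (b, dim) via fancy indexing

assert cond.shape == (2, dim)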
@@ -273,9 +302,11 @@ class VAT(Module):
         self.layers = ModuleList([])
 
         for _ in range(depth):
+            maybe_film = FiLM(dim = dim) if self.has_tasks else None
             maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
 
             self.layers.append(ModuleList([
+                maybe_film,
                 maybe_self_attn,
                 Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
                 FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
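
A small detail that makes this wiring work: `nn.ModuleList` accepts `None` entries, so `maybe_film` can be registered (or not) uniformly and the forward loop can always unpack a four-tuple. A toy illustration of the pattern, not the VAT layers themselves:

import torch
from torch import nn

layers = nn.ModuleList([
    nn.ModuleList([None, nn.Linear(4, 4)])   # optional sublayer left out for this config
])

x = torch.randn(2, 4)

for maybe_film, ff in layers:
    if maybe_film is not None:
        x = maybe_film(x)
    x = ff(x)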
@@ -294,7 +325,9 @@ class VAT(Module):
     def forward(
         self,
         video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
+        *,
         extra = None, # (b d) - batch, dim extra
+        tasks = None, # (b)
         actions = None, # (b k d) - batch, action chunk length, action dimension
     ):
         batch = video_or_image.shape[0]
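
The bare `*` slipped into the signature makes `extra`, `tasks`, and `actions` keyword-only, so a caller can no longer pass conditioning tensors positionally and have them land in the wrong slot. With `vat`, `images`, `tasks`, and `extra` as in the `__main__` block further down:

vat(images, tasks = tasks, extra = extra)   # ok: everything after the video must be named

try:
    vat(images, tasks)                      # positional args past `video_or_image` are rejected
except TypeError:
    pass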
@@ -349,6 +382,13 @@ class VAT(Module):
             view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
             hiddens = hiddens + view_emb
 
+        # maybe tasks
+
+        if exists(tasks):
+            assert self.has_tasks, '`num_tasks` must be set on `VAT` for task conditioning'
+
+            task_emb = self.task_emb[tasks]
+
         # cross from actions to representation trajectory
 
         context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
@@ -368,7 +408,10 @@ class VAT(Module):
 
         # cross attention
 
-        for (maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
+        for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
+
+            if exists(tasks):
+                action_tokens = maybe_film(action_tokens, task_emb)
 
             action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
 
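
Per layer, the order is: FiLM-condition the action tokens (when tasks are given), then residual cross attention into the flattened ViT trajectory, then the residual feedforward. A runnable sketch of one layer of that flow, substituting torch's built-in multi-head attention for the file's own `Attention` (an assumption for brevity, not the repo's implementation):

import torch
from torch import nn

dim = 32

to_gamma_beta = nn.Linear(dim, dim * 2)    # zero-initting this recovers the diff's FiLM behavior
cross_attn = nn.MultiheadAttention(dim, num_heads = 4, batch_first = True)
ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

action_tokens = torch.randn(2, 7, dim)     # (b, action chunk, dim)
layer_context = torch.randn(2, 196, dim)   # flattened (views * time * patches) ViT hiddens
task_emb = torch.randn(2, dim)             # row looked up from the task table

# 1. FiLM-condition the action tokens on the task embedding
gamma, beta = to_gamma_beta(task_emb).chunk(2, dim = -1)
action_tokens = action_tokens * (gamma[:, None, :] + 1.) + beta[:, None, :]

# 2. residual cross attention from actions into the visual trajectory
attended, _ = cross_attn(action_tokens, layer_context, layer_context)
action_tokens = attended + action_tokens

# 3. residual feedforward
action_tokens = ff(action_tokens) + action_tokens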
@@ -422,6 +465,7 @@ if __name__ == '__main__':
         action_chunk_len = 7,
         time_seq_len = 4,
         num_views = 2,
+        num_tasks = 4,
         add_self_attn = True,
         dim_extra_token = 33, # extra token with some variable dimension
         vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
@@ -430,15 +474,16 @@ if __name__ == '__main__':
     )
 
     images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
+    tasks = torch.randint(0, 4, (2,))
     extra = torch.randn(2, 33) # extra internal state
 
     actions = torch.randn(2, 7, 20) # actions for learning
 
-    loss = vat(images, actions = actions, extra = extra)
+    loss = vat(images, actions = actions, tasks = tasks, extra = extra)
    loss.backward()
 
     # after much training
 
-    pred_actions = vat(images)
+    pred_actions = vat(images, tasks = tasks, extra = extra)
 
     assert pred_actions.shape == (2, 7, 20)
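
For completeness, a deployment-style call would normally wrap that last prediction in `torch.no_grad()`, which the smoke test above omits:

import torch

with torch.no_grad():
    pred_actions = vat(images, tasks = tasks, extra = extra)   # same call, minus autograd bookkeeping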