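last tweak to vat: add a `return_hiddens` option that also returns the stacked action token hiddens, and bump version to 1.14.4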

This commit is contained in:
lucidrains
2025-10-23 12:21:09 -07:00
parent 25871013f5
commit a583cb5988
2 changed files with 11 additions and 3 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "vit-pytorch" name = "vit-pytorch"
version = "1.14.2" version = "1.14.4"
description = "Vision Transformer (ViT) - Pytorch" description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" } readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" } license = { file = "LICENSE" }

View File

@@ -329,6 +329,7 @@ class VAT(Module):
extra = None, # (b d) - batch, dim extra extra = None, # (b d) - batch, dim extra
tasks = None, # (b) tasks = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False
): ):
batch = video_or_image.shape[0] batch = video_or_image.shape[0]
return_loss = exists(actions) return_loss = exists(actions)
@@ -408,6 +409,8 @@ class VAT(Module):
# cross attention # cross attention
hiddens = [action_tokens]
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context): for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
if exists(tasks): if exists(tasks):
@@ -420,6 +423,8 @@ class VAT(Module):
action_tokens = ff(action_tokens) + action_tokens action_tokens = ff(action_tokens) + action_tokens
hiddens.append(action_tokens)
# maybe unpack extra # maybe unpack extra
if has_extra: if has_extra:
@@ -432,8 +437,11 @@ class VAT(Module):
pred_action = self.to_pred_action(action_tokens) pred_action = self.to_pred_action(action_tokens)
if not return_loss: if not return_loss:
if not return_hiddens:
return pred_action return pred_action
return pred_action, stack(hiddens)
assert pred_action.shape[1] == actions.shape[1] assert pred_action.shape[1] == actions.shape[1]
# they found l1 loss suffices # they found l1 loss suffices
@@ -484,6 +492,6 @@ if __name__ == '__main__':
# after much training # after much training
pred_actions = vat(images, tasks = tasks, extra = extra) pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True)
assert pred_actions.shape == (2, 7, 20) assert pred_actions.shape == (2, 7, 20)