last tweak to vat

This commit is contained in:
lucidrains
2025-10-23 12:21:09 -07:00
parent 25871013f5
commit a583cb5988
2 changed files with 11 additions and 3 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.14.2"
version = "1.14.4"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }

View File

@@ -329,6 +329,7 @@ class VAT(Module):
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False
):
batch = video_or_image.shape[0]
return_loss = exists(actions)
@@ -408,6 +409,8 @@ class VAT(Module):
# cross attention
hiddens = [action_tokens]
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
if exists(tasks):
@@ -420,6 +423,8 @@ class VAT(Module):
action_tokens = ff(action_tokens) + action_tokens
hiddens.append(action_tokens)
# maybe unpack extra
if has_extra:
@@ -432,7 +437,10 @@ class VAT(Module):
pred_action = self.to_pred_action(action_tokens)
if not return_loss:
return pred_action
if not return_hiddens:
return pred_action
return pred_action, stack(hiddens)
assert pred_action.shape[1] == actions.shape[1]
@@ -484,6 +492,6 @@ if __name__ == '__main__':
# after much training
pred_actions = vat(images, tasks = tasks, extra = extra)
pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True)
assert pred_actions.shape == (2, 7, 20)