From a583cb59882f32b5f949394cea1bf76e6bf0592e Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 23 Oct 2025 12:21:09 -0700 Subject: [PATCH] last tweak to vat --- pyproject.toml | 2 +- vit_pytorch/vat.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f0d190e..7ae8399 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "vit-pytorch" -version = "1.14.2" +version = "1.14.4" description = "Vision Transformer (ViT) - Pytorch" readme = { file = "README.md", content-type = "text/markdown" } license = { file = "LICENSE" } diff --git a/vit_pytorch/vat.py b/vit_pytorch/vat.py index be61c2f..aa8f529 100644 --- a/vit_pytorch/vat.py +++ b/vit_pytorch/vat.py @@ -329,6 +329,7 @@ class VAT(Module): extra = None, # (b d) - batch, dim extra tasks = None, # (b) actions = None, # (b k d) - batch, action chunk length, action dimension + return_hiddens = False ): batch = video_or_image.shape[0] return_loss = exists(actions) @@ -408,6 +409,8 @@ class VAT(Module): # cross attention + hiddens = [action_tokens] + for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context): if exists(tasks): @@ -420,6 +423,8 @@ class VAT(Module): action_tokens = ff(action_tokens) + action_tokens + hiddens.append(action_tokens) + # maybe unpack extra if has_extra: @@ -432,7 +437,10 @@ class VAT(Module): pred_action = self.to_pred_action(action_tokens) if not return_loss: - return pred_action + if not return_hiddens: + return pred_action + + return pred_action, stack(hiddens) assert pred_action.shape[1] == actions.shape[1] @@ -484,6 +492,6 @@ if __name__ == '__main__': # after much training - pred_actions = vat(images, tasks = tasks, extra = extra) + pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True) assert pred_actions.shape == (2, 7, 20)