mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-29 23:52:27 +00:00
last tweak to vat
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vit-pytorch"
|
||||
version = "1.14.2"
|
||||
version = "1.14.4"
|
||||
description = "Vision Transformer (ViT) - Pytorch"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
license = { file = "LICENSE" }
|
||||
|
||||
@@ -329,6 +329,7 @@ class VAT(Module):
|
||||
extra = None, # (b d) - batch, dim extra
|
||||
tasks = None, # (b)
|
||||
actions = None, # (b k d) - batch, action chunk length, action dimension
|
||||
return_hiddens = False
|
||||
):
|
||||
batch = video_or_image.shape[0]
|
||||
return_loss = exists(actions)
|
||||
@@ -408,6 +409,8 @@ class VAT(Module):
|
||||
|
||||
# cross attention
|
||||
|
||||
hiddens = [action_tokens]
|
||||
|
||||
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
|
||||
|
||||
if exists(tasks):
|
||||
@@ -420,6 +423,8 @@ class VAT(Module):
|
||||
|
||||
action_tokens = ff(action_tokens) + action_tokens
|
||||
|
||||
hiddens.append(action_tokens)
|
||||
|
||||
# maybe unpack extra
|
||||
|
||||
if has_extra:
|
||||
@@ -432,7 +437,10 @@ class VAT(Module):
|
||||
pred_action = self.to_pred_action(action_tokens)
|
||||
|
||||
if not return_loss:
|
||||
return pred_action
|
||||
if not return_hiddens:
|
||||
return pred_action
|
||||
|
||||
return pred_action, stack(hiddens)
|
||||
|
||||
assert pred_action.shape[1] == actions.shape[1]
|
||||
|
||||
@@ -484,6 +492,6 @@ if __name__ == '__main__':
|
||||
|
||||
# after much training
|
||||
|
||||
pred_actions = vat(images, tasks = tasks, extra = extra)
|
||||
pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True)
|
||||
|
||||
assert pred_actions.shape == (2, 7, 20)
|
||||
|
||||
Reference in New Issue
Block a user