mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
last tweak to vat
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "vit-pytorch"
|
name = "vit-pytorch"
|
||||||
version = "1.14.2"
|
version = "1.14.4"
|
||||||
description = "Vision Transformer (ViT) - Pytorch"
|
description = "Vision Transformer (ViT) - Pytorch"
|
||||||
readme = { file = "README.md", content-type = "text/markdown" }
|
readme = { file = "README.md", content-type = "text/markdown" }
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
|
|||||||
@@ -329,6 +329,7 @@ class VAT(Module):
|
|||||||
extra = None, # (b d) - batch, dim extra
|
extra = None, # (b d) - batch, dim extra
|
||||||
tasks = None, # (b)
|
tasks = None, # (b)
|
||||||
actions = None, # (b k d) - batch, action chunk length, action dimension
|
actions = None, # (b k d) - batch, action chunk length, action dimension
|
||||||
|
return_hiddens = False
|
||||||
):
|
):
|
||||||
batch = video_or_image.shape[0]
|
batch = video_or_image.shape[0]
|
||||||
return_loss = exists(actions)
|
return_loss = exists(actions)
|
||||||
@@ -408,6 +409,8 @@ class VAT(Module):
|
|||||||
|
|
||||||
# cross attention
|
# cross attention
|
||||||
|
|
||||||
|
hiddens = [action_tokens]
|
||||||
|
|
||||||
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
|
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
|
||||||
|
|
||||||
if exists(tasks):
|
if exists(tasks):
|
||||||
@@ -420,6 +423,8 @@ class VAT(Module):
|
|||||||
|
|
||||||
action_tokens = ff(action_tokens) + action_tokens
|
action_tokens = ff(action_tokens) + action_tokens
|
||||||
|
|
||||||
|
hiddens.append(action_tokens)
|
||||||
|
|
||||||
# maybe unpack extra
|
# maybe unpack extra
|
||||||
|
|
||||||
if has_extra:
|
if has_extra:
|
||||||
@@ -432,8 +437,11 @@ class VAT(Module):
|
|||||||
pred_action = self.to_pred_action(action_tokens)
|
pred_action = self.to_pred_action(action_tokens)
|
||||||
|
|
||||||
if not return_loss:
|
if not return_loss:
|
||||||
|
if not return_hiddens:
|
||||||
return pred_action
|
return pred_action
|
||||||
|
|
||||||
|
return pred_action, stack(hiddens)
|
||||||
|
|
||||||
assert pred_action.shape[1] == actions.shape[1]
|
assert pred_action.shape[1] == actions.shape[1]
|
||||||
|
|
||||||
# they found l1 loss suffices
|
# they found l1 loss suffices
|
||||||
@@ -484,6 +492,6 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# after much training
|
# after much training
|
||||||
|
|
||||||
pred_actions = vat(images, tasks = tasks, extra = extra)
|
pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True)
|
||||||
|
|
||||||
assert pred_actions.shape == (2, 7, 20)
|
assert pred_actions.shape == (2, 7, 20)
|
||||||
|
|||||||
Reference in New Issue
Block a user