From 6693d47d0bcd6789a51fda149736fa1e2bbbb760 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 7 Nov 2024 20:02:02 -0800 Subject: [PATCH] update comment for navit 3d --- vit_pytorch/na_vit_nested_tensor.py | 2 ++ vit_pytorch/na_vit_nested_tensor_3d.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/vit_pytorch/na_vit_nested_tensor.py b/vit_pytorch/na_vit_nested_tensor.py index 81bc519..7d52334 100644 --- a/vit_pytorch/na_vit_nested_tensor.py +++ b/vit_pytorch/na_vit_nested_tensor.py @@ -323,3 +323,5 @@ if __name__ == '__main__': ] assert v(images).shape == (5, 1000) + + v(images).sum().backward() diff --git a/vit_pytorch/na_vit_nested_tensor_3d.py b/vit_pytorch/na_vit_nested_tensor_3d.py index e40758c..ff37774 100644 --- a/vit_pytorch/na_vit_nested_tensor_3d.py +++ b/vit_pytorch/na_vit_nested_tensor_3d.py @@ -336,7 +336,7 @@ class NaViT(Module): if __name__ == '__main__': - # works for torch 2.4 + # works for torch 2.5 v = NaViT( image_size = 256, @@ -362,3 +362,5 @@ if __name__ == '__main__': ] assert v(volumes).shape == (5, 1000) + + v(volumes).sum().backward()