diff --git a/vit_pytorch/na_vit_nested_tensor.py b/vit_pytorch/na_vit_nested_tensor.py index 81bc519..7d52334 100644 --- a/vit_pytorch/na_vit_nested_tensor.py +++ b/vit_pytorch/na_vit_nested_tensor.py @@ -323,3 +323,5 @@ if __name__ == '__main__': ] assert v(images).shape == (5, 1000) + + v(images).sum().backward() diff --git a/vit_pytorch/na_vit_nested_tensor_3d.py b/vit_pytorch/na_vit_nested_tensor_3d.py index e40758c..ff37774 100644 --- a/vit_pytorch/na_vit_nested_tensor_3d.py +++ b/vit_pytorch/na_vit_nested_tensor_3d.py @@ -336,7 +336,7 @@ class NaViT(Module): if __name__ == '__main__': - # works for torch 2.4 + # works for torch 2.5 v = NaViT( image_size = 256, @@ -362,3 +362,5 @@ if __name__ == '__main__': ] assert v(volumes).shape == (5, 1000) + + v(volumes).sum().backward()