update comment for navit 3d

This commit is contained in:
lucidrains
2024-11-07 20:02:02 -08:00
parent 141239ca86
commit 6693d47d0b
2 changed files with 5 additions and 1 deletions

View File

@@ -323,3 +323,5 @@ if __name__ == '__main__':
]
assert v(images).shape == (5, 1000)
v(images).sum().backward()

View File

@@ -336,7 +336,7 @@ class NaViT(Module):
if __name__ == '__main__':
# works for torch 2.4
# works for torch 2.5
v = NaViT(
image_size = 256,
@@ -362,3 +362,5 @@ if __name__ == '__main__':
]
assert v(volumes).shape == (5, 1000)
v(volumes).sum().backward()