mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
update comment for navit 3d
This commit is contained in:
@@ -323,3 +323,5 @@ if __name__ == '__main__':
|
||||
]
|
||||
|
||||
assert v(images).shape == (5, 1000)
|
||||
|
||||
v(images).sum().backward()
|
||||
|
||||
@@ -336,7 +336,7 @@ class NaViT(Module):
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# works for torch 2.4
|
||||
# works for torch 2.5
|
||||
|
||||
v = NaViT(
|
||||
image_size = 256,
|
||||
@@ -362,3 +362,5 @@ if __name__ == '__main__':
|
||||
]
|
||||
|
||||
assert v(volumes).shape == (5, 1000)
|
||||
|
||||
v(volumes).sum().backward()
|
||||
|
||||
Reference in New Issue
Block a user