fix tests

This commit is contained in:
lucidrains
2024-11-10 09:43:54 -08:00
parent 0449865786
commit d47c57e32f
3 changed files with 9 additions and 7 deletions

View File

@@ -6,9 +6,6 @@ from functools import partial
import torch
import packaging.version as pkg_version
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
print('nested tensor NaViT was tested on pytorch 2.5')
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
@@ -152,6 +149,11 @@ class NaViT(Module):
token_dropout_prob: float | None = None
):
super().__init__()
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
print('nested tensor NaViT was tested on pytorch 2.5')
image_height, image_width = pair(image_size)
# what percent of tokens to dropout