diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index fe54d29..7338714 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -28,7 +28,7 @@ jobs: python -m pip install --upgrade pip python -m pip install pytest python -m pip install wheel - python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu + python -m pip install torch==2.5.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Test with pytest run: | diff --git a/vit_pytorch/na_vit_nested_tensor.py b/vit_pytorch/na_vit_nested_tensor.py index 7d52334..6084e6c 100644 --- a/vit_pytorch/na_vit_nested_tensor.py +++ b/vit_pytorch/na_vit_nested_tensor.py @@ -6,8 +6,8 @@ from functools import partial import torch import packaging.version as pkg_version -if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'): - print('nested tensor NaViT was tested on pytorch 2.4') +if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'): + print('nested tensor NaViT was tested on pytorch 2.5') from torch import nn, Tensor import torch.nn.functional as F diff --git a/vit_pytorch/na_vit_nested_tensor_3d.py b/vit_pytorch/na_vit_nested_tensor_3d.py index ff37774..7722c0f 100644 --- a/vit_pytorch/na_vit_nested_tensor_3d.py +++ b/vit_pytorch/na_vit_nested_tensor_3d.py @@ -6,8 +6,8 @@ from functools import partial import torch import packaging.version as pkg_version -if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'): - print('nested tensor NaViT was tested on pytorch 2.4') +if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'): + print('nested tensor NaViT was tested on pytorch 2.5') from torch import nn, Tensor import torch.nn.functional as F