From 0449865786b9d137d68c49dcf7160b07bd3e9b28 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 10 Nov 2024 09:37:48 -0800 Subject: [PATCH] update minimum version for nested tensor of NaViT --- .github/workflows/python-test.yml | 2 +- vit_pytorch/na_vit_nested_tensor.py | 4 ++-- vit_pytorch/na_vit_nested_tensor_3d.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index fe54d29..7338714 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -28,7 +28,7 @@ jobs: python -m pip install --upgrade pip python -m pip install pytest python -m pip install wheel - python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu + python -m pip install torch==2.5.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Test with pytest run: | diff --git a/vit_pytorch/na_vit_nested_tensor.py b/vit_pytorch/na_vit_nested_tensor.py index 7d52334..6084e6c 100644 --- a/vit_pytorch/na_vit_nested_tensor.py +++ b/vit_pytorch/na_vit_nested_tensor.py @@ -6,8 +6,8 @@ from functools import partial import torch import packaging.version as pkg_version -if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'): - print('nested tensor NaViT was tested on pytorch 2.4') +if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'): + print('nested tensor NaViT was tested on pytorch 2.5') from torch import nn, Tensor import torch.nn.functional as F diff --git a/vit_pytorch/na_vit_nested_tensor_3d.py b/vit_pytorch/na_vit_nested_tensor_3d.py index ff37774..7722c0f 100644 --- a/vit_pytorch/na_vit_nested_tensor_3d.py +++ b/vit_pytorch/na_vit_nested_tensor_3d.py @@ -6,8 +6,8 @@ from functools import partial import torch import packaging.version as pkg_version -if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'): - print('nested tensor NaViT was tested on pytorch 2.4') +if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'): + print('nested tensor NaViT was tested on pytorch 2.5') from torch import nn, Tensor import torch.nn.functional as F