Compare commits

...

3 Commits
1.2.0 ... 1.2.2

Author SHA1 Message Date
Phil Wang
ce4bcd08fb address https://github.com/lucidrains/vit-pytorch/issues/266 2023-05-20 08:24:49 -07:00
Phil Wang
ad4ca19775 enforce latest einops 2023-05-08 09:34:14 -07:00
Phil Wang
e1b08c15b9 fix tests 2023-03-19 10:52:47 -07:00
4 changed files with 11 additions and 3 deletions

View File

@@ -27,6 +27,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install pytest
python -m pip install wheel
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.2.0',
version = '1.2.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
@@ -16,7 +16,7 @@ setup(
'image recognition'
],
install_requires=[
'einops>=0.6.0',
'einops>=0.6.1',
'torch>=1.10',
'torchvision'
],

View File

@@ -1,3 +1,10 @@
import torch
from packaging import version
if version.parse(torch.__version__) >= version.parse('2.0.0'):
from einops._torch_specific import allow_ops_in_compiled_graph
allow_ops_in_compiled_graph()
from vit_pytorch.vit import ViT
from vit_pytorch.simple_vit import SimpleViT

View File

@@ -146,7 +146,7 @@ class ViT(nn.Module):
x = self.to_patch_embedding(video)
b, f, n, _ = x.shape
x = x + self.pos_embedding
x = x + self.pos_embedding[:, :f, :n]
if exists(self.spatial_cls_token):
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)