mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce4bcd08fb | ||
|
|
ad4ca19775 | ||
|
|
e1b08c15b9 |
1
.github/workflows/python-test.yml
vendored
1
.github/workflows/python-test.yml
vendored
@@ -27,6 +27,7 @@ jobs:
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pytest
|
||||
python -m pip install wheel
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
|
||||
4
setup.py
4
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.2.0',
|
||||
version = '1.2.2',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description_content_type = 'text/markdown',
|
||||
@@ -16,7 +16,7 @@ setup(
|
||||
'image recognition'
|
||||
],
|
||||
install_requires=[
|
||||
'einops>=0.6.0',
|
||||
'einops>=0.6.1',
|
||||
'torch>=1.10',
|
||||
'torchvision'
|
||||
],
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('2.0.0'):
|
||||
from einops._torch_specific import allow_ops_in_compiled_graph
|
||||
allow_ops_in_compiled_graph()
|
||||
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.simple_vit import SimpleViT
|
||||
|
||||
|
||||
@@ -146,7 +146,7 @@ class ViT(nn.Module):
|
||||
x = self.to_patch_embedding(video)
|
||||
b, f, n, _ = x.shape
|
||||
|
||||
x = x + self.pos_embedding
|
||||
x = x + self.pos_embedding[:, :f, :n]
|
||||
|
||||
if exists(self.spatial_cls_token):
|
||||
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
|
||||
|
||||
Reference in New Issue
Block a user