From 976f48923088e5240d798bc6e7e60f55456cf666 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 22 Dec 2021 09:13:31 -0800 Subject: [PATCH] add some tests --- .github/workflows/python-test.yml | 33 +++++++++++++++++++++++++++++++ setup.py | 8 +++++++- tests/test.py | 20 +++++++++++++++++++ 3 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/python-test.yml create mode 100644 tests/test.py diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml new file mode 100644 index 0000000..f5079cb --- /dev/null +++ b/.github/workflows/python-test.yml @@ -0,0 +1,33 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Test + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.7, 3.8, 3.9] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Test with pytest + run: | + python setup.py test diff --git a/setup.py b/setup.py index 082e6b6..c57b9c3 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.25.1', + version = '0.25.3', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', @@ -19,6 +19,12 @@ setup( 'torch>=1.6', 'torchvision' ], + setup_requires=[ + 'pytest-runner', + ], + tests_require=[ + 'pytest' + ], classifiers=[ 'Development Status :: 4 - Beta', 'Intended Audience :: Developers', diff --git a/tests/test.py b/tests/test.py new file mode 100644 index 0000000..b0e9d77 --- /dev/null +++ b/tests/test.py @@ -0,0 +1,20 @@ +import torch +from vit_pytorch import ViT + +def test(): + v = ViT( + image_size = 256, + patch_size = 32, + num_classes = 1000, + dim = 1024, + depth = 6, + heads = 16, + mlp_dim = 2048, + dropout = 0.1, + emb_dropout = 0.1 + ) + + img = torch.randn(1, 3, 256, 256) + + preds = v(img) + assert preds.shape == (1, 1000), 'correct logits outputted'