diff --git a/setup.py b/setup.py index bc19db9..db94870 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.6.2', + version = '1.6.3', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description_content_type = 'text/markdown', diff --git a/vit_pytorch/cvt.py b/vit_pytorch/cvt.py index 6f214f7..ccf84db 100644 --- a/vit_pytorch/cvt.py +++ b/vit_pytorch/cvt.py @@ -140,12 +140,13 @@ class CvT(nn.Module): s3_heads = 6, s3_depth = 10, s3_mlp_mult = 4, - dropout = 0. + dropout = 0., + channels = 3 ): super().__init__() kwargs = dict(locals()) - dim = 3 + dim = channels layers = [] for prefix in ('s1', 's2', 's3'):