From 0ad09c4cbc57af3c8bb404fda36a47a7bee66ece Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 25 Oct 2023 14:47:58 -0700 Subject: [PATCH] allow channels to be customizable for cvt --- setup.py | 2 +- vit_pytorch/cvt.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index bc19db9..db94870 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.6.2', + version = '1.6.3', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description_content_type = 'text/markdown', diff --git a/vit_pytorch/cvt.py b/vit_pytorch/cvt.py index 6f214f7..ccf84db 100644 --- a/vit_pytorch/cvt.py +++ b/vit_pytorch/cvt.py @@ -140,12 +140,13 @@ class CvT(nn.Module): s3_heads = 6, s3_depth = 10, s3_mlp_mult = 4, - dropout = 0. + dropout = 0., + channels = 3 ): super().__init__() kwargs = dict(locals()) - dim = 3 + dim = channels layers = [] for prefix in ('s1', 's2', 's3'):