allow channels to be customizable for cvt

This commit is contained in:
lucidrains
2023-10-25 14:47:58 -07:00
parent 92b69321f4
commit 0ad09c4cbc
2 changed files with 4 additions and 3 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup( setup(
name = 'vit-pytorch', name = 'vit-pytorch',
packages = find_packages(exclude=['examples']), packages = find_packages(exclude=['examples']),
version = '1.6.2', version = '1.6.3',
license='MIT', license='MIT',
description = 'Vision Transformer (ViT) - Pytorch', description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown', long_description_content_type = 'text/markdown',

View File

@@ -140,12 +140,13 @@ class CvT(nn.Module):
s3_heads = 6, s3_heads = 6,
s3_depth = 10, s3_depth = 10,
s3_mlp_mult = 4, s3_mlp_mult = 4,
dropout = 0. dropout = 0.,
channels = 3
): ):
super().__init__() super().__init__()
kwargs = dict(locals()) kwargs = dict(locals())
dim = 3 dim = channels
layers = [] layers = []
for prefix in ('s1', 's2', 's3'): for prefix in ('s1', 's2', 's3'):