mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
allow channels to be customizable for cvt
This commit is contained in:
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
|||||||
setup(
|
setup(
|
||||||
name = 'vit-pytorch',
|
name = 'vit-pytorch',
|
||||||
packages = find_packages(exclude=['examples']),
|
packages = find_packages(exclude=['examples']),
|
||||||
version = '1.6.2',
|
version = '1.6.3',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'Vision Transformer (ViT) - Pytorch',
|
description = 'Vision Transformer (ViT) - Pytorch',
|
||||||
long_description_content_type = 'text/markdown',
|
long_description_content_type = 'text/markdown',
|
||||||
|
|||||||
@@ -140,12 +140,13 @@ class CvT(nn.Module):
|
|||||||
s3_heads = 6,
|
s3_heads = 6,
|
||||||
s3_depth = 10,
|
s3_depth = 10,
|
||||||
s3_mlp_mult = 4,
|
s3_mlp_mult = 4,
|
||||||
dropout = 0.
|
dropout = 0.,
|
||||||
|
channels = 3
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
kwargs = dict(locals())
|
kwargs = dict(locals())
|
||||||
|
|
||||||
dim = 3
|
dim = channels
|
||||||
layers = []
|
layers = []
|
||||||
|
|
||||||
for prefix in ('s1', 's2', 's3'):
|
for prefix in ('s1', 's2', 's3'):
|
||||||
|
|||||||
Reference in New Issue
Block a user