fix hidden dimension in MaxViT thanks to @arquolo

This commit is contained in:
Phil Wang
2022-06-24 23:28:35 -07:00
parent 6460119f65
commit 2c6dd7010a
2 changed files with 7 additions and 7 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.35.4',
version = '0.35.5',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',

View File

@@ -100,12 +100,12 @@ def MBConv(
stride = 2 if downsample else 1
net = nn.Sequential(
nn.Conv2d(dim_in, dim_out, 1),
nn.BatchNorm2d(dim_out),
nn.SiLU(),
nn.Conv2d(dim_out, dim_out, 3, stride = stride, padding = 1, groups = dim_out),
SqueezeExcitation(dim_out, shrinkage_rate = shrinkage_rate),
nn.Conv2d(dim_out, dim_out, 1),
nn.Conv2d(dim_in, hidden_dim, 1),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = dim_out),
SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
nn.Conv2d(hidden_dim, dim_out, 1),
nn.BatchNorm2d(dim_out)
)