diff --git a/setup.py b/setup.py index 3d4324f..59f676b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.35.4', + version = '0.35.5', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description_content_type = 'text/markdown', diff --git a/vit_pytorch/max_vit.py b/vit_pytorch/max_vit.py index c48b162..246adea 100644 --- a/vit_pytorch/max_vit.py +++ b/vit_pytorch/max_vit.py @@ -100,12 +100,12 @@ def MBConv( stride = 2 if downsample else 1 net = nn.Sequential( - nn.Conv2d(dim_in, dim_out, 1), - nn.BatchNorm2d(dim_out), - nn.SiLU(), - nn.Conv2d(dim_out, dim_out, 3, stride = stride, padding = 1, groups = dim_out), - SqueezeExcitation(dim_out, shrinkage_rate = shrinkage_rate), - nn.Conv2d(dim_out, dim_out, 1), + nn.Conv2d(dim_in, hidden_dim, 1), + nn.BatchNorm2d(hidden_dim), + nn.GELU(), + nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = dim_out), + SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate), + nn.Conv2d(hidden_dim, dim_out, 1), nn.BatchNorm2d(dim_out) )