From 2c6dd7010af59a4e09c97104ccc54fca46eae42d Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 24 Jun 2022 23:28:35 -0700 Subject: [PATCH] fix hidden dimension in MaxViT thanks to @arquolo --- setup.py | 2 +- vit_pytorch/max_vit.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index 3d4324f..59f676b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.35.4', + version = '0.35.5', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description_content_type = 'text/markdown', diff --git a/vit_pytorch/max_vit.py b/vit_pytorch/max_vit.py index c48b162..246adea 100644 --- a/vit_pytorch/max_vit.py +++ b/vit_pytorch/max_vit.py @@ -100,12 +100,12 @@ def MBConv( stride = 2 if downsample else 1 net = nn.Sequential( - nn.Conv2d(dim_in, dim_out, 1), - nn.BatchNorm2d(dim_out), - nn.SiLU(), - nn.Conv2d(dim_out, dim_out, 3, stride = stride, padding = 1, groups = dim_out), - SqueezeExcitation(dim_out, shrinkage_rate = shrinkage_rate), - nn.Conv2d(dim_out, dim_out, 1), + nn.Conv2d(dim_in, hidden_dim, 1), + nn.BatchNorm2d(hidden_dim), + nn.GELU(), + nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = dim_out), + SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate), + nn.Conv2d(hidden_dim, dim_out, 1), nn.BatchNorm2d(dim_out) )