From 9f87d1c43b723e59e3f767119ff6e8ed2b0cabf0 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 29 Jun 2022 08:53:09 -0700
Subject: [PATCH] follow @arquolo feedback and advice for MaxViT

---
 setup.py               | 2 +-
 vit_pytorch/max_vit.py | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index 59f676b..5a41575 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.35.5',
+  version = '0.35.6',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/max_vit.py b/vit_pytorch/max_vit.py
index 246adea..8359f0c 100644
--- a/vit_pytorch/max_vit.py
+++ b/vit_pytorch/max_vit.py
@@ -103,7 +103,9 @@ def MBConv(
         nn.Conv2d(dim_in, hidden_dim, 1),
         nn.BatchNorm2d(hidden_dim),
         nn.GELU(),
-        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = dim_out),
+        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
+        nn.BatchNorm2d(hidden_dim),
+        nn.GELU(),
         SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
         nn.Conv2d(hidden_dim, dim_out, 1),
         nn.BatchNorm2d(dim_out)