diff --git a/setup.py b/setup.py
index a0a3a5b..990dad8 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.28.0',
+  version = '0.28.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/scalable_vit.py b/vit_pytorch/scalable_vit.py
index 3cf4650..11a21e0 100644
--- a/vit_pytorch/scalable_vit.py
+++ b/vit_pytorch/scalable_vit.py
@@ -81,8 +81,8 @@ class ScalableSelfAttention(nn.Module):
         self,
         dim,
         heads = 8,
-        dim_key = 64,
-        dim_value = 64,
+        dim_key = 32,
+        dim_value = 32,
         dropout = 0.,
         reduction_factor = 1
     ):
@@ -132,8 +132,8 @@ class InteractiveWindowedSelfAttention(nn.Module):
         dim,
         window_size,
         heads = 8,
-        dim_key = 64,
-        dim_value = 64,
+        dim_key = 32,
+        dim_value = 32,
         dropout = 0.
     ):
         super().__init__()
@@ -199,12 +199,12 @@ class Transformer(nn.Module):
         heads = 8,
         ff_expansion_factor = 4,
         dropout = 0.,
-        ssa_dim_key = 64,
-        ssa_dim_value = 64,
+        ssa_dim_key = 32,
+        ssa_dim_value = 32,
         ssa_reduction_factor = 1,
-        iwsa_dim_key = 64,
-        iwsa_dim_value = 64,
-        iwsa_window_size = 64,
+        iwsa_dim_key = 32,
+        iwsa_dim_value = 32,
+        iwsa_window_size = None,
         norm_output = True
     ):
         super().__init__()
@@ -244,12 +244,12 @@ class ScalableViT(nn.Module):
         depth,
         heads,
         reduction_factor,
+        window_size = None,
+        iwsa_dim_key = 32,
+        iwsa_dim_value = 32,
+        ssa_dim_key = 32,
+        ssa_dim_value = 32,
         ff_expansion_factor = 4,
-        iwsa_dim_key = 64,
-        iwsa_dim_value = 64,
-        window_size = 64,
-        ssa_dim_key = 64,
-        ssa_dim_value = 64,
         channels = 3,
         dropout = 0.
     ):
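
Usage note: after this change the SSA/IWSA key and value dimensions default to 32 instead of 64, and window_size defaults to None rather than 64, so it can simply be omitted. A minimal sketch of instantiating ScalableViT under the new signature; it assumes the constructor also takes num_classes and dim and accepts per-stage tuples as in the library README, and the specific hyperparameter values are illustrative, not part of this diff.

import torch
from vit_pytorch.scalable_vit import ScalableViT

# illustrative values (assumptions, not taken from this diff)
model = ScalableViT(
    num_classes = 1000,
    dim = 64,                         # starting model dimension
    heads = (2, 4, 8, 16),            # attention heads per stage
    depth = (2, 2, 20, 2),            # transformer blocks per stage
    reduction_factor = (8, 4, 2, 1),  # key/value downsampling for scalable self-attention (SSA)
    dropout = 0.1,
    # window_size is intentionally omitted: it now defaults to None,
    # and dim_key / dim_value for both SSA and IWSA default to 32
)

img = torch.randn(1, 3, 256, 256)
preds = model(img)  # (1, 1000)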