diff --git a/setup.py b/setup.py
index be84ada..04e07c5 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.0.0',
+  version = '1.0.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/learnable_memory_vit.py b/vit_pytorch/learnable_memory_vit.py
index 7764052..ecda757 100644
--- a/vit_pytorch/learnable_memory_vit.py
+++ b/vit_pytorch/learnable_memory_vit.py
@@ -118,7 +118,9 @@ class ViT(nn.Module):
 
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
+            nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
+            nn.LayerNorm(dim)
         )
 
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
diff --git a/vit_pytorch/twins_svt.py b/vit_pytorch/twins_svt.py
index 8a548da..ea888b8 100644
--- a/vit_pytorch/twins_svt.py
+++ b/vit_pytorch/twins_svt.py
@@ -71,7 +71,12 @@ class PatchEmbedding(nn.Module):
         self.dim = dim
         self.dim_out = dim_out
         self.patch_size = patch_size
-        self.proj = nn.Conv2d(patch_size ** 2 * dim, dim_out, 1)
+
+        self.proj = nn.Sequential(
+            LayerNorm(patch_size ** 2 * dim),
+            nn.Conv2d(patch_size ** 2 * dim, dim_out, 1),
+            LayerNorm(dim_out)
+        )
 
     def forward(self, fmap):
         p = self.patch_size
diff --git a/vit_pytorch/vit_with_patch_merger.py b/vit_pytorch/vit_with_patch_merger.py
index 5690ea8..7f1360b 100644
--- a/vit_pytorch/vit_with_patch_merger.py
+++ b/vit_pytorch/vit_with_patch_merger.py
@@ -121,7 +121,9 @@ class ViT(nn.Module):
 
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
+            nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
+            nn.LayerNorm(dim)
         )
 
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
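
For context, here is a minimal standalone sketch of the patch embedding after this change, as it appears in learnable_memory_vit.py and vit_with_patch_merger.py: LayerNorm is applied both to the raw flattened patches (before the linear projection) and to the projected tokens (after it). The dimensions below (16x16 patches, 3 channels, dim 1024, 256x256 image) are hypothetical example values, not fixed by the diff; einops is assumed installed, as the repo already depends on it.

import torch
from torch import nn
from einops.layers.torch import Rearrange

# hypothetical example dimensions
patch_height = patch_width = 16
channels = 3
dim = 1024
patch_dim = channels * patch_height * patch_width

# the updated patch embedding: pre-norm on flattened patches, post-norm on tokens
to_patch_embedding = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
    nn.LayerNorm(patch_dim),
    nn.Linear(patch_dim, dim),
    nn.LayerNorm(dim)
)

img = torch.randn(1, channels, 256, 256)
tokens = to_patch_embedding(img)
print(tokens.shape)  # torch.Size([1, 256, 1024]) - (256 / 16) ** 2 = 256 patches

Note that twins_svt.py uses the module-local LayerNorm class from that file (which normalizes channel-first conv feature maps) rather than nn.LayerNorm, since its projection is a 1x1 Conv2d over a spatial feature map instead of a linear layer over token vectors.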