From 5699ed7d139062020d1394f0e85a07f706c87c09 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 10 Feb 2023 10:39:50 -0800
Subject: [PATCH] double down on dual patch norm, fix MAE and Simmim to be compatible with dual patchnorm

---
 setup.py                 | 2 +-
 vit_pytorch/ats_vit.py   | 2 ++
 vit_pytorch/cait.py      | 2 ++
 vit_pytorch/cross_vit.py | 2 ++
 vit_pytorch/efficient.py | 2 ++
 vit_pytorch/mae.py       | 7 +++++--
 vit_pytorch/simmim.py    | 7 +++++--
 vit_pytorch/vivit.py     | 2 ++
 8 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/setup.py b/setup.py
index 04e07c5..c3d56b3 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.0.1',
+  version = '1.0.2',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/ats_vit.py b/vit_pytorch/ats_vit.py
index 69951be..779c400 100644
--- a/vit_pytorch/ats_vit.py
+++ b/vit_pytorch/ats_vit.py
@@ -230,7 +230,9 @@ class ViT(nn.Module):
 
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
+            nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
+            nn.LayerNorm(dim)
         )
 
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
diff --git a/vit_pytorch/cait.py b/vit_pytorch/cait.py
index 5968c6c..eac9185 100644
--- a/vit_pytorch/cait.py
+++ b/vit_pytorch/cait.py
@@ -150,7 +150,9 @@ class CaiT(nn.Module):
 
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+            nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
+            nn.LayerNorm(dim)
         )
 
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
diff --git a/vit_pytorch/cross_vit.py b/vit_pytorch/cross_vit.py
index 4bb637f..b894a2f 100644
--- a/vit_pytorch/cross_vit.py
+++ b/vit_pytorch/cross_vit.py
@@ -186,7 +186,9 @@ class ImageEmbedder(nn.Module):
 
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+            nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
+            nn.LayerNorm(dim)
         )
 
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
diff --git a/vit_pytorch/efficient.py b/vit_pytorch/efficient.py
index 8e9033d..c499331 100644
--- a/vit_pytorch/efficient.py
+++ b/vit_pytorch/efficient.py
@@ -17,7 +17,9 @@ class ViT(nn.Module):
 
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+            nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
+            nn.LayerNorm(dim)
         )
 
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
diff --git a/vit_pytorch/mae.py b/vit_pytorch/mae.py
index b2d750a..e2b993c 100644
--- a/vit_pytorch/mae.py
+++ b/vit_pytorch/mae.py
@@ -24,8 +24,11 @@ class MAE(nn.Module):
 
         self.encoder = encoder
         num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
-        self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
-        pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
+
+        self.to_patch = encoder.to_patch_embedding[0]
+        self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])
+
+        pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]
 
         # decoder parameters
         self.decoder_dim = decoder_dim
diff --git a/vit_pytorch/simmim.py b/vit_pytorch/simmim.py
index 710b4a1..5804ce3 100644
--- a/vit_pytorch/simmim.py
+++ b/vit_pytorch/simmim.py
@@ -18,8 +18,11 @@ class SimMIM(nn.Module):
 
         self.encoder = encoder
         num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
-        self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
-        pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
+
+        self.to_patch = encoder.to_patch_embedding[0]
+        self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])
+
+        pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]
 
         # simple linear head
 
diff --git a/vit_pytorch/vivit.py b/vit_pytorch/vivit.py
index 1acb6f7..082b699 100644
--- a/vit_pytorch/vivit.py
+++ b/vit_pytorch/vivit.py
@@ -120,7 +120,9 @@ class ViT(nn.Module):
 
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
+            nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
+            nn.LayerNorm(dim)
         )
 
         self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
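
For reference, below is a minimal standalone sketch of the dual patch norm embedding this patch introduces (LayerNorm on the raw patch pixels before the projection, and again on the projected embedding) and of how MAE / SimMIM index into it after the change. The image size, patch size and dims are illustrative values, not taken from the patch.

import torch
from torch import nn
from einops.layers.torch import Rearrange

# illustrative hyperparameters (assumed, not from the patch)
patch_size, channels, dim = 16, 3, 1024
patch_dim = channels * patch_size * patch_size

# dual patch norm: normalize raw patches, project, normalize embeddings
to_patch_embedding = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
    nn.LayerNorm(patch_dim),
    nn.Linear(patch_dim, dim),
    nn.LayerNorm(dim)
)

# how MAE / SimMIM split the module after this patch
to_patch = to_patch_embedding[0]                                  # image -> raw patch pixels
patch_to_emb = nn.Sequential(*to_patch_embedding[1:])             # raw patches -> embeddings (norm, project, norm)
pixel_values_per_patch = to_patch_embedding[2].weight.shape[-1]   # in_features of the Linear, i.e. patch_dim

img = torch.randn(1, channels, 256, 256)
patches = to_patch(img)            # (1, 256, 768) raw patch pixels, used as reconstruction target
tokens = patch_to_emb(patches)     # (1, 256, 1024) embeddings fed to the encoder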