mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
double down on dual patch norm, fix MAE and Simmim to be compatible with dual patchnorm
This commit is contained in:
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.0.1',
|
||||
version = '1.0.2',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description_content_type = 'text/markdown',
|
||||
|
||||
@@ -230,7 +230,9 @@ class ViT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
|
||||
@@ -150,7 +150,9 @@ class CaiT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
|
||||
|
||||
@@ -186,7 +186,9 @@ class ImageEmbedder(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
|
||||
@@ -17,7 +17,9 @@ class ViT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
|
||||
@@ -24,8 +24,11 @@ class MAE(nn.Module):
|
||||
|
||||
self.encoder = encoder
|
||||
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
|
||||
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
|
||||
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
|
||||
|
||||
self.to_patch = encoder.to_patch_embedding[0]
|
||||
self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])
|
||||
|
||||
pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]
|
||||
|
||||
# decoder parameters
|
||||
self.decoder_dim = decoder_dim
|
||||
|
||||
@@ -18,8 +18,11 @@ class SimMIM(nn.Module):
|
||||
|
||||
self.encoder = encoder
|
||||
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
|
||||
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
|
||||
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
|
||||
|
||||
self.to_patch = encoder.to_patch_embedding[0]
|
||||
self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])
|
||||
|
||||
pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]
|
||||
|
||||
# simple linear head
|
||||
|
||||
|
||||
@@ -120,7 +120,9 @@ class ViT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
|
||||
|
||||
Reference in New Issue
Block a user