mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
seeing a signal with dual patchnorm in another repository, fully incorporate
This commit is contained in:
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.0.0',
|
||||
version = '1.0.1',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description_content_type = 'text/markdown',
|
||||
|
||||
@@ -118,7 +118,9 @@ class ViT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
|
||||
@@ -71,7 +71,12 @@ class PatchEmbedding(nn.Module):
|
||||
self.dim = dim
|
||||
self.dim_out = dim_out
|
||||
self.patch_size = patch_size
|
||||
self.proj = nn.Conv2d(patch_size ** 2 * dim, dim_out, 1)
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
LayerNorm(patch_size ** 2 * dim),
|
||||
nn.Conv2d(patch_size ** 2 * dim, dim_out, 1),
|
||||
LayerNorm(dim_out)
|
||||
)
|
||||
|
||||
def forward(self, fmap):
|
||||
p = self.patch_size
|
||||
|
||||
@@ -121,7 +121,9 @@ class ViT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
|
||||
Reference in New Issue
Block a user