diff --git a/setup.py b/setup.py index e93a580..54afec4 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ with open('README.md') as f: setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.7.12', + version = '1.7.14', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description=long_description, diff --git a/vit_pytorch/regionvit.py b/vit_pytorch/regionvit.py index 2e155a1..57e4c64 100644 --- a/vit_pytorch/regionvit.py +++ b/vit_pytorch/regionvit.py @@ -20,6 +20,18 @@ def divisible_by(val, d): # helper classes +class ChanLayerNorm(nn.Module): + def __init__(self, dim, eps = 1e-5): + super().__init__() + self.eps = eps + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) + + def forward(self, x): + var = torch.var(x, dim = 1, unbiased = False, keepdim = True) + mean = torch.mean(x, dim = 1, keepdim = True) + return (x - mean) / (var + self.eps).sqrt() * self.g + self.b + class Downsample(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() @@ -212,10 +224,10 @@ class RegionViT(nn.Module): if tokenize_local_3_conv: self.local_encoder = nn.Sequential( nn.Conv2d(3, init_dim, 3, 2, 1), - nn.LayerNorm(init_dim), + ChanLayerNorm(init_dim), nn.GELU(), nn.Conv2d(init_dim, init_dim, 3, 2, 1), - nn.LayerNorm(init_dim), + ChanLayerNorm(init_dim), nn.GELU(), nn.Conv2d(init_dim, init_dim, 3, 1, 1) )