From e5324242be61bcbf433e129e914aa4b4fa1a79a0 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 5 Aug 2021 12:55:48 -0700
Subject: [PATCH] fix wrong norm in nest

---
 setup.py            |  2 +-
 vit_pytorch/nest.py | 14 ++++++++++++--
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/setup.py b/setup.py
index 8a979ef..187543f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.20.1',
+  version = '0.20.2',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/nest.py b/vit_pytorch/nest.py
index 18919bd..e0d383d 100644
--- a/vit_pytorch/nest.py
+++ b/vit_pytorch/nest.py
@@ -10,10 +10,20 @@ from einops.layers.torch import Rearrange, Reduce
 def cast_tuple(val, depth):
     return val if isinstance(val, tuple) else ((val,) * depth)
 
-LayerNorm = partial(nn.InstanceNorm2d, affine = True)
-
 # classes
 
+class LayerNorm(nn.Module):
+    def __init__(self, dim, eps = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
+
+    def forward(self, x):
+        std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
+        mean = torch.mean(x, dim = 1, keepdim = True)
+        return (x - mean) / (std + self.eps) * self.g + self.b
+
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
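For reference, a minimal sketch of what the patched norm computes, assuming only that torch is installed; the tensor shape and printed checks are illustrative and not part of the library. It contrasts the new channel-wise LayerNorm from vit_pytorch/nest.py with the nn.InstanceNorm2d-based definition it replaces, which normalizes each channel over its spatial extent rather than each spatial position over channels.

# Minimal sketch: the new LayerNorm normalizes every spatial position across the
# channel dimension (dim = 1) of an NCHW tensor; the removed
# partial(nn.InstanceNorm2d, affine = True) normalized each channel over its
# spatial extent instead. Shapes below are illustrative assumptions.
import torch
from torch import nn

class LayerNorm(nn.Module):
    # same definition as the patched vit_pytorch/nest.py
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (std + self.eps) * self.g + self.b

x = torch.randn(2, 8, 4, 4)  # (batch, channels, height, width)

out = LayerNorm(8)(x)
# zero mean / unit variance across channels at every spatial position
print(out.mean(dim = 1).abs().max())              # ~0
print(out.var(dim = 1, unbiased = False).mean())  # ~1

old_norm = nn.InstanceNorm2d(8, affine = True)
# the replaced norm reduces over (H, W) per channel, a different statistic
print(old_norm(x).mean(dim = (2, 3)).abs().max()) # ~0 per channel over space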