From bcfb0f054ae98584a3ddc3e5e7836390fbfa1ab9 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 12 Oct 2023 19:48:31 -0700
Subject: [PATCH] use Rearrange layers

---
 vit_pytorch/xcit.py | 24 +++++-------------------
 1 file changed, 5 insertions(+), 19 deletions(-)

diff --git a/vit_pytorch/xcit.py b/vit_pytorch/xcit.py
index afc02d3..bc5a85c 100644
--- a/vit_pytorch/xcit.py
+++ b/vit_pytorch/xcit.py
@@ -37,20 +37,6 @@ def dropout_layers(layers, dropout):
     layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
     return layers
 
-# helper classes
-
-class ChanLayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
-        super().__init__()
-        self.eps = eps
-        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
-        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
-
-    def forward(self, x):
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
-
 # classes
 
 class LayerScale(Module):
@@ -168,17 +154,17 @@ class LocalPatchInteraction(Module):
         padding = kernel_size // 2
 
         self.net = nn.Sequential(
-            ChanLayerNorm(dim),
+            nn.LayerNorm(dim),
+            Rearrange('b h w c -> b c h w'),
             nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
             nn.BatchNorm2d(dim),
             nn.GELU(),
-            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim)
+            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
+            Rearrange('b c h w -> b h w c'),
         )
 
     def forward(self, x):
-        x = rearrange(x, 'b h w c -> b c h w')
-        x = self.net(x)
-        return rearrange(x, 'b c h w -> b h w c')
+        return self.net(x)
 
 class Transformer(Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
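
For reference, a minimal standalone sketch of what the refactored LocalPatchInteraction body now does (assuming torch and einops are installed; the free-standing net, dim, and x names below are illustrative scaffolding, not part of the patch):

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim, kernel_size = 64, 3
    padding = kernel_size // 2

    # mirrors the new self.net: LayerNorm runs on the channels-last input,
    # the first Rearrange moves channels in front for the depthwise convs,
    # and the final Rearrange restores (b, h, w, c) so forward() can simply
    # return net(x) with no manual rearranges
    net = nn.Sequential(
        nn.LayerNorm(dim),
        Rearrange('b h w c -> b c h w'),
        nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
        nn.BatchNorm2d(dim),
        nn.GELU(),
        nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
        Rearrange('b c h w -> b h w c'),
    )

    x = torch.randn(2, 14, 14, dim)   # (batch, height, width, channels)
    assert net(x).shape == x.shape    # output stays channels-last, same shape as input

Folding the rearranges into the Sequential also lets nn.LayerNorm (over the trailing channel dimension) replace the hand-rolled ChanLayerNorm, which normalized the same channel axis but in channels-first layout.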