use Rearrange layers

lucidrains
2023-10-12 19:48:31 -07:00
parent 3dfb1579f7
commit bcfb0f054a


@@ -37,20 +37,6 @@ def dropout_layers(layers, dropout):
     layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
     return layers
 
-# helper classes
-
-class ChanLayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
-        super().__init__()
-        self.eps = eps
-        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
-        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
-
-    def forward(self, x):
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
-
 # classes
 
 class LayerScale(Module):
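Context for the removal (not part of the commit itself): ChanLayerNorm normalized over the channel dimension of a channels-first tensor, which is exactly what nn.LayerNorm(dim) does when the tensor is kept channels-last, so the helper becomes redundant once the layout permutations move inside the module. A minimal sketch of that equivalence, assuming the default eps = 1e-5 and freshly initialized (unit scale, zero shift) affine parameters:

```python
import torch
import torch.nn as nn

dim = 8
x = torch.randn(2, dim, 4, 4)  # b c h w, channels-first

# the removed ChanLayerNorm, reproduced inline with g = 1, b = 0 (its init values)
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
chan_ln = (x - mean) / (var + 1e-5).sqrt()

# nn.LayerNorm over the channel dim of the channels-last view, then back
ln = nn.LayerNorm(dim)  # weight = 1, bias = 0 at init, eps = 1e-5
torch_ln = ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

assert torch.allclose(chan_ln, torch_ln, atol = 1e-5)
```
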
@@ -168,17 +154,17 @@ class LocalPatchInteraction(Module):
         padding = kernel_size // 2
 
         self.net = nn.Sequential(
-            ChanLayerNorm(dim),
+            nn.LayerNorm(dim),
+            Rearrange('b h w c -> b c h w'),
             nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
             nn.BatchNorm2d(dim),
             nn.GELU(),
-            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim)
+            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
+            Rearrange('b c h w -> b h w c'),
         )
 
     def forward(self, x):
-        x = rearrange(x, 'b h w c -> b c h w')
-        x = self.net(x)
-        return rearrange(x, 'b c h w -> b h w c')
+        return self.net(x)
 
 class Transformer(Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
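For readability, here is a minimal sketch of how LocalPatchInteraction reads after this change, assembled from the hunk above. The constructor signature (dim, kernel_size = 3) and the imports fall outside the diff context, so they are assumptions:

```python
from torch import nn
from einops.layers.torch import Rearrange

class LocalPatchInteraction(nn.Module):
    def __init__(self, dim, kernel_size = 3):  # signature assumed, not shown in the hunk
        super().__init__()
        padding = kernel_size // 2

        self.net = nn.Sequential(
            nn.LayerNorm(dim),                 # normalizes the channels-last input (b h w c)
            Rearrange('b h w c -> b c h w'),   # to channels-first for the depthwise convs
            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
            nn.BatchNorm2d(dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
            Rearrange('b c h w -> b h w c'),   # back to channels-last for the caller
        )

    def forward(self, x):
        # no manual rearrange calls needed; the Rearrange layers handle layout inside the Sequential
        return self.net(x)

# usage sketch: channels-last patch grid in, same shape out
# x = torch.randn(1, 14, 14, 256)
# LocalPatchInteraction(256)(x).shape  # torch.Size([1, 14, 14, 256])
```

Folding the einops Rearrange layers into the nn.Sequential keeps the layout permutations next to the layers that need them, and forward reduces to a plain pass-through.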