mirror of https://github.com/lucidrains/vit-pytorch.git (synced 2025-12-30 08:02:29 +00:00)
use Rearrange layers
@@ -37,20 +37,6 @@ def dropout_layers(layers, dropout):
     layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
     return layers
 
-# helper classes
-
-class ChanLayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
-        super().__init__()
-        self.eps = eps
-        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
-        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
-
-    def forward(self, x):
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
-
 # classes
 
 class LayerScale(Module):
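A sanity check on the first hunk (not part of the commit): the deleted ChanLayerNorm normalized a channels-first tensor across dim 1 with biased variance and eps = 1e-5, which is exactly what nn.LayerNorm(dim) does to the trailing channel dimension of a channels-last tensor. So swapping one for the other, together with the layout change in the second hunk, preserves the computation. A minimal sketch with illustrative shapes:

    import torch
    import torch.nn as nn

    class ChanLayerNorm(nn.Module):
        # verbatim copy of the removed helper, kept here only for comparison
        def __init__(self, dim, eps = 1e-5):
            super().__init__()
            self.eps = eps
            self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
            self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

        def forward(self, x):
            var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
            mean = torch.mean(x, dim = 1, keepdim = True)
            return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

    dim = 8
    x = torch.randn(2, 4, 4, dim)  # (batch, height, width, channels)

    # old path: channels-first norm; new path: channels-last nn.LayerNorm.
    # with freshly initialized (identity) affine parameters they coincide.
    out_old = ChanLayerNorm(dim)(x.permute(0, 3, 1, 2))
    out_new = nn.LayerNorm(dim)(x)

    assert torch.allclose(out_old.permute(0, 2, 3, 1), out_new, atol = 1e-6)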
@@ -168,17 +154,17 @@ class LocalPatchInteraction(Module):
         padding = kernel_size // 2
 
         self.net = nn.Sequential(
-            ChanLayerNorm(dim),
+            nn.LayerNorm(dim),
+            Rearrange('b h w c -> b c h w'),
             nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
             nn.BatchNorm2d(dim),
             nn.GELU(),
-            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim)
+            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
+            Rearrange('b c h w -> b h w c'),
         )
 
     def forward(self, x):
-        x = rearrange(x, 'b h w c -> b c h w')
-        x = self.net(x)
-        return rearrange(x, 'b c h w -> b h w c')
+        return self.net(x)
 
 class Transformer(Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
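The second hunk is a pure refactor: einops' Rearrange layer applies the same permutation as the functional rearrange call, but as an nn.Module that can sit inside the Sequential, so forward collapses to return self.net(x). A minimal sketch of the equivalence, assuming einops is installed and using illustrative dim and kernel_size values:

    import torch
    import torch.nn as nn
    from einops import rearrange
    from einops.layers.torch import Rearrange

    dim, kernel_size = 8, 3
    padding = kernel_size // 2
    conv = nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim)

    # before this commit: permute around the channels-first conv by hand
    def forward_old(x):
        x = rearrange(x, 'b h w c -> b c h w')
        x = conv(x)
        return rearrange(x, 'b c h w -> b h w c')

    # after this commit: the permutations live inside the Sequential itself
    net = nn.Sequential(
        Rearrange('b h w c -> b c h w'),
        conv,
        Rearrange('b c h w -> b h w c'),
    )

    x = torch.randn(2, 4, 4, dim)  # channels-last, as LocalPatchInteraction receives it
    assert torch.allclose(forward_old(x), net(x))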