just remove PreNorm wrapper from all ViTs, as it is unlikely to change at this point

This commit is contained in:
Phil Wang
2023-08-14 09:48:55 -07:00
parent 4264efd906
commit 8208c859a5
21 changed files with 137 additions and 232 deletions

View File

@@ -25,15 +25,6 @@ class ChanLayerNorm(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanLayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
class OverlappingPatchEmbed(nn.Module):
def __init__(self, dim_in, dim_out, stride = 2):
super().__init__()
@@ -59,6 +50,7 @@ class FeedForward(nn.Module):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -85,6 +77,8 @@ class DSSA(nn.Module):
self.window_size = window_size
inner_dim = dim_head * heads
self.norm = ChanLayerNorm(dim)
self.attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
@@ -138,6 +132,8 @@ class DSSA(nn.Module):
assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
num_windows = (height // wsz) * (width // wsz)
x = self.norm(x)
# fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
@@ -225,8 +221,8 @@ class Transformer(nn.Module):
for ind in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = dropout)),
DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mult = ff_mult, dropout = dropout),
]))
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()