Compare commits

...

1 Commits
1.6.8 ... 1.4.3

Author SHA1 Message Date
Phil Wang
ad7d0df22c just remove PreNorm wrapper from all ViTs, as it is unlikely to change at this point 2023-08-14 09:41:40 -07:00
21 changed files with 137 additions and 232 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.4.2',
version = '1.4.3',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',

View File

@@ -110,18 +110,11 @@ class AdaptiveTokenSampling(nn.Module):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -138,6 +131,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -154,6 +148,7 @@ class Attention(nn.Module):
def forward(self, x, *, mask):
num_tokens = x.shape[1]
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -189,8 +184,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):

View File

@@ -44,18 +44,11 @@ class LayerScale(nn.Module):
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -72,6 +65,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -89,6 +83,7 @@ class Attention(nn.Module):
def forward(self, x, context = None):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
context = x if not exists(context) else torch.cat((x, context), dim = 1)
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
@@ -115,8 +110,8 @@ class Transformer(nn.Module):
for ind in range(depth):
self.layers.append(nn.ModuleList([
LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = ind + 1),
LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = ind + 1)
]))
def forward(self, x, context = None):
layers = dropout_layers(self.layers, dropout = self.layer_dropout)

View File

@@ -13,22 +13,13 @@ def exists(val):
def default(val, d):
return val if exists(val) else d
# pre-layernorm
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
# feedforward
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -47,6 +38,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -60,6 +52,7 @@ class Attention(nn.Module):
def forward(self, x, context = None, kv_include_self = False):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
context = default(context, x)
if kv_include_self:
@@ -86,8 +79,8 @@ class Transformer(nn.Module):
self.norm = nn.LayerNorm(dim)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
@@ -121,8 +114,8 @@ class CrossTransformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
ProjectInOut(sm_dim, lg_dim, PreNorm(lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout))),
ProjectInOut(lg_dim, sm_dim, PreNorm(sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout)))
ProjectInOut(sm_dim, lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout)),
ProjectInOut(lg_dim, sm_dim, ttention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout))
]))
def forward(self, sm_tokens, lg_tokens):

View File

@@ -34,19 +34,11 @@ class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
LayerNorm(dim),
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -75,6 +67,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -89,6 +82,8 @@ class Attention(nn.Module):
def forward(self, x):
shape = x.shape
b, n, _, y, h = *shape, self.heads
x = self.norm(x)
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
@@ -107,8 +102,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_mult, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -5,25 +5,11 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -40,6 +26,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.dropout = nn.Dropout(dropout)
@@ -59,6 +46,8 @@ class Attention(nn.Module):
def forward(self, x):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
@@ -86,13 +75,13 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x)
x = ff(x)
x = attn(x) + x
x = ff(x) + x
return x
class DeepViT(nn.Module):

View File

@@ -26,16 +26,6 @@ class ExcludeCLS(nn.Module):
x = self.fn(x, **kwargs)
return torch.cat((cls_token, x), dim = 1)
# prenorm
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
# feed forward related classes
class DepthWiseConv2d(nn.Module):
@@ -52,6 +42,7 @@ class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Conv2d(dim, hidden_dim, 1),
nn.Hardswish(),
DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
@@ -77,6 +68,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -88,6 +80,8 @@ class Attention(nn.Module):
def forward(self, x):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
@@ -106,8 +100,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
ExcludeCLS(Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))))
Residual(Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
ExcludeCLS(Residual(FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -19,20 +19,20 @@ def cast_tuple(val, length = 1):
# helper classes
class PreNormResidual(nn.Module):
class Residual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x)) + x
return self.fn(x) + x
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -132,6 +132,7 @@ class Attention(nn.Module):
self.heads = dim // dim_head
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.attend = nn.Sequential(
@@ -160,6 +161,8 @@ class Attention(nn.Module):
def forward(self, x):
batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
x = self.norm(x)
# flatten
x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
@@ -259,13 +262,13 @@ class MaxViT(nn.Module):
shrinkage_rate = mbconv_shrinkage_rate
),
Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
)

View File

@@ -22,20 +22,11 @@ def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.SiLU(),
nn.Dropout(dropout),
@@ -53,6 +44,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
@@ -64,9 +56,10 @@ class Attention(nn.Module):
)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(
t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
@@ -88,8 +81,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
Attention(dim, heads, dim_head, dropout),
FeedForward(dim, mlp_dim, dropout)
]))
def forward(self, x):
@@ -167,11 +160,9 @@ class MobileViTBlock(nn.Module):
# Global representations
_, _, h, w = x.shape
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
# Fusion
x = self.conv3(x)

View File

@@ -24,19 +24,11 @@ class LayerNorm(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mlp_mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
LayerNorm(dim),
nn.Conv2d(dim, dim * mlp_mult, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -54,6 +46,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
@@ -66,6 +59,8 @@ class Attention(nn.Module):
def forward(self, x):
b, c, h, w, heads = *x.shape, self.heads
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)
@@ -93,8 +88,8 @@ class Transformer(nn.Module):
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
Attention(dim, heads = heads, dropout = dropout),
FeedForward(dim, mlp_mult, dropout = dropout)
]))
def forward(self, x):
*_, h, w = x.shape

View File

@@ -19,18 +19,11 @@ class Parallel(nn.Module):
def forward(self, x):
return sum([fn(x) for fn in self.fns])
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -49,6 +42,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -60,6 +54,7 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -77,8 +72,8 @@ class Transformer(nn.Module):
super().__init__()
self.layers = nn.ModuleList([])
attn_block = lambda: PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
ff_block = lambda: PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
attn_block = lambda: Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)
ff_block = lambda: FeedForward(dim, mlp_dim, dropout = dropout)
for _ in range(depth):
self.layers.append(nn.ModuleList([

View File

@@ -17,18 +17,11 @@ def conv_output_size(image_size, kernel_size, stride, padding = 0):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -47,6 +40,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -58,6 +52,8 @@ class Attention(nn.Module):
def forward(self, x):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
@@ -76,8 +72,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -55,14 +55,6 @@ class DepthWiseConv2d(nn.Module):
# helper classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class SpatialConv(nn.Module):
def __init__(self, dim_in, dim_out, kernel, bias = False):
super().__init__()
@@ -86,6 +78,7 @@ class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0., use_glu = True):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim * 2 if use_glu else hidden_dim),
GEGLU() if use_glu else nn.GELU(),
nn.Dropout(dropout),
@@ -103,6 +96,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -121,6 +115,9 @@ class Attention(nn.Module):
b, n, _, h = *x.shape, self.heads
to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {}
x = self.norm(x)
q = self.to_q(x, **to_q_kwargs)
qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
@@ -162,8 +159,8 @@ class Transformer(nn.Module):
self.pos_emb = AxialRotaryEmbedding(dim_head, max_freq = image_size)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv),
FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu)
]))
def forward(self, x, fmap_dims):
pos_emb = self.pos_emb(x[:, 1:])

View File

@@ -33,15 +33,6 @@ class ChanLayerNorm(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanLayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
class Downsample(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
@@ -65,6 +56,7 @@ class FeedForward(nn.Module):
super().__init__()
inner_dim = dim * expansion_factor
self.net = nn.Sequential(
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -92,6 +84,7 @@ class ScalableSelfAttention(nn.Module):
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.norm = ChanLayerNorm(dim)
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)
@@ -104,6 +97,8 @@ class ScalableSelfAttention(nn.Module):
def forward(self, x):
height, width, heads = *x.shape[-2:], self.heads
x = self.norm(x)
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
# split out heads
@@ -145,6 +140,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.norm = ChanLayerNorm(dim)
self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
@@ -159,6 +155,8 @@ class InteractiveWindowedSelfAttention(nn.Module):
def forward(self, x):
height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size
x = self.norm(x)
wsz_h, wsz_w = default(wsz, height), default(wsz, width)
assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'
@@ -217,11 +215,11 @@ class Transformer(nn.Module):
is_first = ind == 0
self.layers.append(nn.ModuleList([
PreNorm(dim, ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout)),
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout),
FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
PEG(dim) if is_first else None,
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
PreNorm(dim, InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout))
FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout)
]))
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

View File

@@ -25,15 +25,6 @@ class ChanLayerNorm(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanLayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
class OverlappingPatchEmbed(nn.Module):
def __init__(self, dim_in, dim_out, stride = 2):
super().__init__()
@@ -59,6 +50,7 @@ class FeedForward(nn.Module):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -85,6 +77,8 @@ class DSSA(nn.Module):
self.window_size = window_size
inner_dim = dim_head * heads
self.norm = ChanLayerNorm(dim)
self.attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
@@ -138,6 +132,8 @@ class DSSA(nn.Module):
assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
num_windows = (height // wsz) * (width // wsz)
x = self.norm(x)
# fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
@@ -225,8 +221,8 @@ class Transformer(nn.Module):
for ind in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = dropout)),
DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mult = ff_mult, dropout = dropout),
]))
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

View File

@@ -42,20 +42,11 @@ class LayerNorm(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
LayerNorm(dim),
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -99,6 +90,7 @@ class LocalAttention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = LayerNorm(dim)
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)
@@ -108,6 +100,8 @@ class LocalAttention(nn.Module):
)
def forward(self, fmap):
fmap = self.norm(fmap)
shape, p = fmap.shape, self.patch_size
b, n, x, y, h = *shape, self.heads
x, y = map(lambda t: t // p, (x, y))
@@ -132,6 +126,8 @@ class GlobalAttention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = LayerNorm(dim)
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)
@@ -143,6 +139,8 @@ class GlobalAttention(nn.Module):
)
def forward(self, x):
x = self.norm(x)
shape = x.shape
b, n, _, y, h = *shape, self.heads
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
@@ -164,10 +162,10 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size))) if has_local else nn.Identity(),
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))) if has_local else nn.Identity(),
Residual(PreNorm(dim, GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k))),
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout)))
Residual(LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size)) if has_local else nn.Identity(),
Residual(FeedForward(dim, mlp_mult, dropout = dropout)) if has_local else nn.Identity(),
Residual(GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k)),
Residual(FeedForward(dim, mlp_mult, dropout = dropout))
]))
def forward(self, x):
for local_attn, ff1, global_attn, ff2 in self.layers:

View File

@@ -6,18 +6,11 @@ from einops.layers.torch import Rearrange
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Layernorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -36,6 +29,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -47,6 +41,7 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -65,8 +60,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -11,18 +11,11 @@ def pair(t):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -41,6 +34,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -52,6 +46,7 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -70,8 +65,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -13,18 +13,11 @@ def pair(t):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim)
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -41,6 +34,7 @@ class LSA(nn.Module):
self.heads = heads
self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -52,6 +46,7 @@ class LSA(nn.Module):
)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -74,8 +69,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -30,18 +30,11 @@ class PatchDropout(nn.Module):
return x[batch_indices, patch_indices_keep]
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -60,6 +53,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -71,6 +65,7 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -89,8 +84,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -14,18 +14,11 @@ def pair(t):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -44,6 +37,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -55,6 +49,7 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -74,8 +69,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers: