allow for qk norm to be turned off for na vit nested tensor

Author: lucidrains
Date:   2024-11-20 10:59:22 -08:00
Parent: f6d7287b6b
Commit: 24196a3e8a

3 changed files with 15 additions and 13 deletions
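
In short: a qk_norm flag is threaded through Attention and Transformer and exposed on NaViT as qk_rmsnorm; when it is False, the per-head query/key LayerNorms are replaced by nn.Identity(), so queries and keys reach attention unnormalized. A minimal usage sketch follows, assuming the nested-tensor NaViT module from this repo; the import path, the non-diff constructor arguments, and the list-of-images forward signature are assumptions, and only the qk_rmsnorm flag itself comes from this commit.

import torch
from vit_pytorch.na_vit_nested_tensor import NaViT  # import path assumed

v = NaViT(
    image_size = 256,        # assumed example values, not part of this diff
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024,
    qk_rmsnorm = False       # new in this commit: swaps the q/k LayerNorms for nn.Identity()
)

# the nested-tensor NaViT targets variable-resolution inputs (forward signature assumed)
images = [torch.randn(3, 256, 256), torch.randn(3, 128, 160)]
preds = v(images)            # e.g. (2, 1000) class logits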

File 1/3: setup.py

@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.8.7',
+  version = '1.8.8',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,

File 2/3: NaViT nested tensor module

@@ -41,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
     )

 class Attention(Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
         super().__init__()
         self.norm = nn.LayerNorm(dim, bias = False)
@@ -56,8 +56,8 @@ class Attention(Module):
         # in the paper, they employ qk rmsnorm, a way to stabilize attention
         # will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
-        self.query_norm = nn.LayerNorm(dim_head, bias = False)
-        self.key_norm = nn.LayerNorm(dim_head, bias = False)
+        self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
+        self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
         self.dropout = dropout
@@ -111,13 +111,13 @@ class Attention(Module):
         return self.to_out(out)

 class Transformer(Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
         super().__init__()
         self.layers = ModuleList([])
         for _ in range(depth):
             self.layers.append(ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
@@ -146,6 +146,7 @@ class NaViT(Module):
         dim_head = 64,
         dropout = 0.,
         emb_dropout = 0.,
+        qk_rmsnorm = True,
         token_dropout_prob: float | None = None
     ):
         super().__init__()
@@ -184,7 +185,7 @@ class NaViT(Module):
         self.dropout = nn.Dropout(emb_dropout)
-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
         # final attention pooling queries
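
The toggle itself is just the conditional visible above: with qk_norm off, query_norm and key_norm collapse to nn.Identity() and the tensors pass through untouched. A standalone sketch of that pattern, not the repo's full Attention class:

import torch
from torch import nn

class QKNorm(nn.Module):
    # minimal illustration of the qk_norm toggle from the diff:
    # LayerNorm over the head dimension when enabled, a no-op otherwise
    def __init__(self, dim_head, qk_norm = True):
        super().__init__()
        self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
        self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()

    def forward(self, q, k):
        return self.query_norm(q), self.key_norm(k)

q = k = torch.randn(2, 8, 16, 64)                # (batch, heads, seq, dim_head)

normed_q, _ = QKNorm(64, qk_norm = True)(q, k)   # normalized along dim_head
plain_q, _  = QKNorm(64, qk_norm = False)(q, k)  # identity: returned unchanged
assert torch.equal(plain_q, q)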

File 3/3: NaViT nested tensor module with register tokens

@@ -41,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
     )

 class Attention(Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
         super().__init__()
         self.norm = nn.LayerNorm(dim, bias = False)
@@ -56,8 +56,8 @@ class Attention(Module):
         # in the paper, they employ qk rmsnorm, a way to stabilize attention
         # will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
-        self.query_norm = nn.LayerNorm(dim_head, bias = False)
-        self.key_norm = nn.LayerNorm(dim_head, bias = False)
+        self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
+        self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
         self.dropout = dropout
@@ -123,13 +123,13 @@ class Attention(Module):
         return self.to_out(out)

 class Transformer(Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
         super().__init__()
         self.layers = ModuleList([])
         for _ in range(depth):
             self.layers.append(ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
@@ -161,6 +161,7 @@ class NaViT(Module):
         dropout = 0.,
         emb_dropout = 0.,
         num_registers = 4,
+        qk_rmsnorm = True,
         token_dropout_prob: float | None = None
     ):
         super().__init__()
@@ -209,7 +210,7 @@ class NaViT(Module):
         self.dropout = nn.Dropout(emb_dropout)
-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
         # final attention pooling queries
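
The register-token variant gets exactly the same plumbing. A hedged construction sketch: the module path and every argument except num_registers and qk_rmsnorm are assumptions.

import torch
from vit_pytorch.na_vit_nested_tensor_with_registers import NaViT  # path assumed

v = NaViT(
    image_size = 256,        # assumed example values, not part of this diff
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024,
    num_registers = 4,       # default shown in the diff context above
    qk_rmsnorm = False       # new flag from this commit
)

images = [torch.randn(3, 224, 224), torch.randn(3, 176, 208)]
logits = v(images)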