fix multiheaded qk rmsnorm in nViT

lucidrains
2024-10-10 19:15:17 -07:00
parent 36ddc7a6ba
commit e300cdd7dc
2 changed files with 6 additions and 6 deletions

setup.py

@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.8.4',
+  version = '1.8.5',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,

vit_pytorch/normalized_vit.py

@@ -76,8 +76,8 @@ class Attention(Module):
         self.dropout = dropout

-        self.q_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
-        self.k_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
+        self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
+        self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))

         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
@@ -90,15 +90,15 @@ class Attention(Module):
     ):
         q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

-        q = q * self.q_scale
-        k = k * self.k_scale
-
         q, k, v = map(self.split_heads, (q, k, v))

         # query key rmsnorm

         q, k = map(l2norm, (q, k))

+        q = q * self.q_scale
+        k = k * self.k_scale
+
         # scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16

         out = F.scaled_dot_product_attention(
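
For reference, a minimal self-contained sketch of the corrected behaviour: l2-normalize queries and keys per head, multiply by a learned per-head scale of shape (heads, 1, dim_head), and run scaled_dot_product_attention with scale = 1., since the dk ** 0.25 factors on q and k carry the attention scaling (the eq. 16 comment in the diff). The names q_scale, k_scale, split_heads and l2norm mirror the diff; the surrounding module (QKNormAttention, the linear projections, the defaults) is assumed for illustration and is not the exact nViT Attention class.

import torch
from torch import nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

def l2norm(t, dim = -1):
    # per the diff, "qk rmsnorm" is realized as plain l2 normalization along the feature dim
    return F.normalize(t, dim = dim)

class QKNormAttention(nn.Module):
    # illustrative sketch only - not the actual nViT Attention module
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        dim_inner = heads * dim_head

        self.to_q = nn.Linear(dim, dim_inner, bias = False)
        self.to_k = nn.Linear(dim, dim_inner, bias = False)
        self.to_v = nn.Linear(dim, dim_inner, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

        # the fix: one scale per head, broadcasting over the sequence dimension,
        # instead of a single flat dim_inner vector applied before splitting heads
        self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
        self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

    def forward(self, x):
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        q, k, v = map(self.split_heads, (q, k, v))   # (batch, heads, seq, dim_head)

        # normalize each head's queries / keys, then apply the learned per-head scale
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # scale = 1. because the dk ** 0.25 init on q_scale and k_scale folds the
        # attention scaling into s_qk, per the eq. 16 comment in the diff
        out = F.scaled_dot_product_attention(q, k, v, scale = 1.)

        return self.to_out(self.merge_heads(out))

# quick shape check
attn = QKNormAttention(dim = 256)
x = torch.randn(2, 65, 256)
assert attn(x).shape == (2, 65, 256)

Why the move matters, as far as the diff shows: previously the scale was a flat dim_inner vector multiplied into q and k before split_heads and before the l2norm, so at initialization the uniform dk ** 0.25 factor was simply cancelled by the subsequent normalization; applying a (heads, 1, dim_head) scale after splitting and normalizing keeps the per-head s_qk in effect.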