mirror of https://github.com/lucidrains/vit-pytorch.git
fix multiheaded qk rmsnorm in nViT
setup.py (2 changed lines)

@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.8.4',
+  version = '1.8.5',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
@@ -76,8 +76,8 @@ class Attention(Module):
 
         self.dropout = dropout
 
-        self.q_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
-        self.k_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
+        self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
+        self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
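The new parameter shape is chosen so the scale can be applied after split_heads: queries and keys then have shape (batch, heads, n, dim_head), and a (heads, 1, dim_head) tensor broadcasts over the batch and sequence dimensions, giving each head its own learned per-feature scale. A minimal shape check, with illustrative batch size, sequence length, heads and dim_head (not values taken from the commit):

import torch
import torch.nn as nn

heads, dim_head = 8, 64

# after split_heads, queries / keys have shape (batch, heads, n, dim_head)
q = torch.randn(2, heads, 16, dim_head)

# new parameter: a learned scale per head and per feature, initialized to dim_head ** 0.25
q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))

# (heads, 1, dim_head) broadcasts over the batch and sequence dimensions
assert (q * q_scale).shape == q.shape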
@@ -90,15 +90,15 @@ class Attention(Module):
     ):
         q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
 
-        q = q * self.q_scale
-        k = k * self.k_scale
-
         q, k, v = map(self.split_heads, (q, k, v))
 
         # query key rmsnorm
 
         q, k = map(l2norm, (q, k))
 
+        q = q * self.q_scale
+        k = k * self.k_scale
+
         # scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
 
         out = F.scaled_dot_product_attention(
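Putting the two hunks together: the learned s_qk scale is now applied after the per-head l2norm instead of before the heads are split, so the dim_head ** 0.25 factor is no longer erased by the normalization, and scaled_dot_product_attention can run with scale = 1 as the eq. 16 comment describes. Below is a condensed sketch of the corrected forward pass, assuming a recent PyTorch whose F.scaled_dot_product_attention accepts a scale argument; nvit_attention, to_q / to_k / to_v and the q_scale / k_scale arguments are hypothetical standins for the module's own projections and parameters:

import torch
import torch.nn.functional as F

def l2norm(t, dim = -1):
    # unit-normalize along the feature dimension (rmsnorm without a learned gain)
    return F.normalize(t, dim = dim)

def nvit_attention(x, to_q, to_k, to_v, q_scale, k_scale, heads):
    b, n, _ = x.shape
    q, k, v = to_q(x), to_k(x), to_v(x)

    # split heads: (b, n, h * d) -> (b, h, n, d)
    q, k, v = (t.reshape(b, n, heads, -1).transpose(1, 2) for t in (q, k, v))

    # query key rmsnorm first, then the learned per-head scale (s_qk)
    q, k = map(l2norm, (q, k))
    q = q * q_scale
    k = k * k_scale

    # scale = 1. because dim_head ** 0.25 is already folded into q_scale and k_scale
    out = F.scaled_dot_product_attention(q, k, v, scale = 1.)

    # merge heads back: (b, h, n, d) -> (b, n, h * d)
    return out.transpose(1, 2).reshape(b, n, -1)

The ordering of l2norm before the learned scale is the substance of the fix; the projections and head rearranges above are only simplified stand-ins for what the module defines itself.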