From e300cdd7dc602793c4551f87a26c81693c939318 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 10 Oct 2024 19:15:17 -0700
Subject: [PATCH] fix multiheaded qk rmsnorm in nViT

---
 setup.py                      |  2 +-
 vit_pytorch/normalized_vit.py | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/setup.py b/setup.py
index 01a0e3c..575faff 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.8.4',
+  version = '1.8.5',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
diff --git a/vit_pytorch/normalized_vit.py b/vit_pytorch/normalized_vit.py
index 5a21126..54f7dd2 100644
--- a/vit_pytorch/normalized_vit.py
+++ b/vit_pytorch/normalized_vit.py
@@ -76,8 +76,8 @@ class Attention(Module):
 
         self.dropout = dropout
 
-        self.q_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
-        self.k_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
+        self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
+        self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
@@ -90,15 +90,15 @@ class Attention(Module):
     ):
         q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
 
-        q = q * self.q_scale
-        k = k * self.k_scale
-
         q, k, v = map(self.split_heads, (q, k, v))
 
         # query key rmsnorm
 
         q, k = map(l2norm, (q, k))
 
+        q = q * self.q_scale
+        k = k * self.k_scale
+
         # scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
 
         out = F.scaled_dot_product_attention(
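
Note (not part of the patch): a minimal sketch of the corrected order of operations in the attention forward pass, assuming the per-head scales are applied after splitting heads and after the qk rmsnorm, as the patch does. Variable names and tensor sizes below are illustrative only, the scales are plain tensors rather than nn.Parameter for brevity, and the scale keyword of F.scaled_dot_product_attention assumes PyTorch >= 2.1.

import torch
import torch.nn.functional as F

batch, seq, heads, dim_head = 2, 16, 8, 64

# learned per-head scales, initialised at dim_head ** 0.25 as in the patch
q_scale = torch.ones(heads, 1, dim_head) * (dim_head ** 0.25)
k_scale = torch.ones(heads, 1, dim_head) * (dim_head ** 0.25)

# queries, keys, values already split into heads: (batch, heads, seq, dim_head)
q, k, v = (torch.randn(batch, heads, seq, dim_head) for _ in range(3))

# query-key rmsnorm: unit l2 norm along the per-head feature dimension
q, k = F.normalize(q, dim = -1), F.normalize(k, dim = -1)

# scales applied after the norm; (heads, 1, dim_head) broadcasts against
# (batch, heads, seq, dim_head), giving each head its own learned scale
q = q * q_scale
k = k * k_scale

# softmax scale set to 1., since the dim_head ** 0.25 factor is carried by q_scale / k_scale
out = F.scaled_dot_product_attention(q, k, v, scale = 1.)
print(out.shape)  # torch.Size([2, 8, 16, 64])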