Compare commits

...

1 Commit
main ... 0.18.2

Author SHA1 Message Date
Phil Wang
a954089eb7 only apply scaling after applying 2d rel pos bias 2021-05-10 10:46:43 -07:00
2 changed files with 4 additions and 2 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.18.1',
version = '0.18.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -95,10 +95,12 @@ class Attention(nn.Module):
qkv = (q, self.to_k(x), self.to_v(x))
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
dots = einsum('b h i d, b h j d -> b h i j', q, k)
dots = self.apply_pos_bias(dots)
dots *= self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)