Compare commits


1 Commit
1.7.3 ... 1.7.4

Author      SHA1        Message                                                                   Date
lucidrains  9992a615d1  attention re-use in lookup vit should use pre-softmax attention matrix   2024-07-19 19:23:38 -07:00
2 changed files with 13 additions and 12 deletions

setup.py

@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.7.3',
+  version = '1.7.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,

vit_pytorch/look_vit.py

@@ -99,8 +99,8 @@ class Attention(Module):
         self,
         x,
         context = None,
-        return_attn = False,
-        attn = None
+        return_qk_sim = False,
+        qk_sim = None
     ):
         x = self.norm(x)
@@ -119,20 +119,21 @@ class Attention(Module):
             q, k = tuple(self.split_heads(t) for t in qk)
             q = q * self.scale
-            sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
-            attn = self.attend(sim)
-            attn = self.dropout(attn)
+            qk_sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
         else:
-            assert exists(attn), 'attention matrix must be passed in for reusing previous attention'
+            assert exists(qk_sim), 'qk sim matrix must be passed in for reusing previous attention'
+        attn = self.attend(qk_sim)
+        attn = self.dropout(attn)
         out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
         out = self.to_out(out)
-        if not return_attn:
+        if not return_qk_sim:
             return out
-        return out, attn
+        return out, qk_sim
# LookViT
@@ -228,7 +229,7 @@ class LookViT(Module):
             # main tokens cross attends (lookup) on the high res tokens
-            lookup_out, lookup_attn = lookup_cross_attn(tokens, highres_tokens, return_attn = True) # return attention as they reuse the attention matrix
+            lookup_out, qk_sim = lookup_cross_attn(tokens, highres_tokens, return_qk_sim = True) # return attention as they reuse the attention matrix
             tokens = lookup_out + tokens
             tokens = attn(tokens) + tokens
@@ -236,9 +237,9 @@ class LookViT(Module):
             # attention-reuse
-            lookup_attn = rearrange(lookup_attn, 'b h i j -> b h j i') # transpose for reverse cross attention
+            qk_sim = rearrange(qk_sim, 'b h i j -> b h j i') # transpose for reverse cross attention
-            highres_tokens = highres_attn(highres_tokens, tokens, attn = lookup_attn) + highres_tokens
+            highres_tokens = highres_attn(highres_tokens, tokens, qk_sim = qk_sim) + highres_tokens
             highres_tokens = highres_norm(highres_tokens)
             highres_tokens = highres_mlp(highres_tokens) + highres_tokens
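
Below is a minimal standalone sketch (not part of the diff) of why the reuse has to use the pre-softmax similarity matrix: softmax normalizes over the last dimension, so transposing the post-softmax attention for the reverse (highres -> main) cross attention leaves rows that no longer sum to one, whereas transposing qk_sim and then applying softmax yields a valid attention distribution in the reverse direction. Tensor shapes and variable names here are illustrative, not taken from the repository.

import torch
from einops import einsum, rearrange

b, h, i, j, d = 1, 2, 4, 6, 8   # batch, heads, main tokens, highres tokens, head dim
q = torch.randn(b, h, i, d)     # queries from the main (lookup) tokens
k = torch.randn(b, h, j, d)     # keys from the high resolution tokens

# pre-softmax similarity, analogous to what return_qk_sim = True hands back
qk_sim = einsum(q, k, 'b h i d, b h j d -> b h i j')

# forward direction: main tokens attend to highres tokens
forward_attn = qk_sim.softmax(dim = -1)

# reverse direction: reuse the similarity, transpose, then softmax over the new last dim
reverse_attn = rearrange(qk_sim, 'b h i j -> b h j i').softmax(dim = -1)

# transposing the post-softmax matrix instead does not give valid attention rows
wrong = rearrange(forward_attn, 'b h i j -> b h j i')

print(reverse_attn.sum(dim = -1)[0, 0])   # all ones: each row is a proper distribution
print(wrong.sum(dim = -1)[0, 0])          # generally not ones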