correct the need for post-attention dropout

This commit is contained in:
Phil Wang
2022-03-30 10:50:57 -07:00
parent 6d7298d8ad
commit 4e6a42a0ca
20 changed files with 61 additions and 2 deletions

View File

@@ -104,6 +104,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.use_ds_conv = use_ds_conv
@@ -148,6 +149,7 @@ class Attention(nn.Module):
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)