correct need for post-attention dropout

Phil Wang
2022-03-30 10:50:57 -07:00
parent 6d7298d8ad
commit 4e6a42a0ca
20 changed files with 61 additions and 2 deletions


@@ -42,6 +42,8 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
@@ -56,6 +58,7 @@ class Attention(nn.Module):
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
         attn = self.attend(dots)
+        attn = self.dropout(attn)
         out = torch.matmul(attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
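For context, below is a minimal, self-contained sketch of an Attention module with the post-softmax dropout in place, mirroring the two hunks above. Only the lines shown in the diff are taken from the commit; the rest of the class (inner_dim, project_out, the forward-pass plumbing, and the usage at the bottom) is reconstructed in typical vit-pytorch style and should be read as an illustration, not the exact file contents.

import torch
from torch import nn
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads                     # assumed helper, not in the diff
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)               # line added by this commit

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # project to queries, keys, values and split into heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)                        # line added by this commit: dropout on attention weights

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

if __name__ == '__main__':
    # hypothetical sizes, chosen only to exercise the module
    attn = Attention(dim = 256, heads = 8, dim_head = 64, dropout = 0.1)
    x = torch.randn(1, 65, 256)      # (batch, tokens, dim)
    print(attn(x).shape)             # torch.Size([1, 65, 256])

Note that nn.Dropout is inert in eval() mode, so the new call only regularizes training: it randomly zeroes individual attention weights after the softmax (the attention dropout used in the original Transformer and ViT), whereas the pre-existing Dropout inside to_out acts on the projected output rather than on the attention map itself.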