mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
correct need for post-attention dropout
This commit is contained in:
@@ -42,6 +42,8 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
@@ -56,6 +58,7 @@ class Attention(nn.Module):
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
Reference in New Issue
Block a user