mirror of https://github.com/lucidrains/vit-pytorch.git (synced 2025-12-30 08:02:29 +00:00)
correct need for post-attention dropout
@@ -42,6 +42,8 @@ class Attention(nn.Module):
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
+        self.dropout = nn.Dropout(dropout)
+
         self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
 
         self.reattn_norm = nn.Sequential(
@@ -64,6 +66,7 @@ class Attention(nn.Module):
 
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
         attn = dots.softmax(dim=-1)
+        attn = self.dropout(attn)
 
         # re-attention
 
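For context, here is a minimal sketch of how the changed Attention module might look with the post-attention dropout in place. It is assembled from the diff context above plus assumed DeepViT-style pieces: the to_out projection, the forward-pass wiring, and the exact contents of reattn_norm are assumptions, not taken verbatim from this commit.

# A minimal sketch, not pulled verbatim from the repository: it combines the
# diff context above with assumed DeepViT-style pieces (to_out projection,
# forward-pass wiring, Rearrange-based reattn_norm) to show where the new
# nn.Dropout sits relative to the softmax and the re-attention step.
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.dropout = nn.Dropout(dropout)  # added by this commit

        self.reattn_weights = nn.Parameter(torch.randn(heads, heads))

        self.reattn_norm = nn.Sequential(
            Rearrange('b h i j -> b i j h'),
            nn.LayerNorm(heads),
            Rearrange('b i j h -> b h i j')
        )

        # assumed output projection (standard in this family of modules)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # attention
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim = -1)
        attn = self.dropout(attn)  # post-attention dropout added by this commit

        # re-attention: mix attention maps across heads, then renormalize
        attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
        attn = self.reattn_norm(attn)

        # aggregate values and project back to the model dimension
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# quick shape check
layer = Attention(dim = 128, heads = 8, dim_head = 64, dropout = 0.1)
x = torch.randn(1, 16, 128)
print(layer(x).shape)  # torch.Size([1, 16, 128])

Dropout on the attention map after the softmax regularizes which tokens each query attends to; here it is applied before the re-attention mixing across heads, so the mixed maps are computed from the already-dropped weights.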