mirror of https://github.com/lucidrains/vit-pytorch.git (synced 2025-12-30 08:02:29 +00:00)
correct need for post-attention dropout
@@ -52,6 +52,7 @@ class Attention(nn.Module):
         self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))
 
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
 
         out_batch_norm = nn.BatchNorm2d(dim_out)
         nn.init.zeros_(out_batch_norm.weight)
@@ -100,6 +101,7 @@ class Attention(nn.Module):
         dots = self.apply_pos_bias(dots)
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
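For context, the change applies dropout to the attention matrix itself: the weights produced by the softmax are randomly zeroed (and rescaled) during training before they are used to aggregate the values. Below is a minimal, self-contained sketch of a generic multi-head self-attention block showing where that post-attention dropout sits. The class name SimpleAttention and its linear q/k/v projections are illustrative assumptions, not the repository's actual module, which (as the diff shows) uses Conv2d/BatchNorm2d projections and a positional bias.

import torch
from torch import nn, einsum
from einops import rearrange

class SimpleAttention(nn.Module):
    # Sketch only: a plain multi-head self-attention block with post-attention dropout.
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)   # post-attention dropout, as added in this commit
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        # x: (batch, tokens, dim)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)            # regularize the softmaxed attention weights

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# usage: attn = SimpleAttention(dim = 128, dropout = 0.1); y = attn(torch.randn(1, 64, 128))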