From 4e6a42a0ca9c2981bc6299d5ff954d9032ab773c Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 30 Mar 2022 10:50:57 -0700
Subject: [PATCH] correct need for post-attention dropout

---
 setup.py                             | 2 +-
 vit_pytorch/ats_vit.py               | 3 +++
 vit_pytorch/cait.py                  | 4 ++++
 vit_pytorch/cross_vit.py             | 3 +++
 vit_pytorch/crossformer.py           | 4 ++++
 vit_pytorch/cvt.py                   | 2 ++
 vit_pytorch/deepvit.py               | 3 +++
 vit_pytorch/levit.py                 | 2 ++
 vit_pytorch/local_vit.py             | 2 ++
 vit_pytorch/mobile_vit.py            | 5 +++++
 vit_pytorch/nest.py                  | 2 ++
 vit_pytorch/parallel_vit.py          | 3 +++
 vit_pytorch/pit.py                   | 2 ++
 vit_pytorch/regionvit.py             | 8 +++++++-
 vit_pytorch/rvt.py                   | 2 ++
 vit_pytorch/scalable_vit.py          | 4 ++++
 vit_pytorch/twins_svt.py             | 3 +++
 vit_pytorch/vit.py                   | 3 +++
 vit_pytorch/vit_for_small_dataset.py | 3 +++
 vit_pytorch/vit_with_patch_merger.py | 3 +++
 20 files changed, 61 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index dfd9a0e..7b05251 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.29.1',
+  version = '0.30.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/ats_vit.py b/vit_pytorch/ats_vit.py
index 0af1017..69951be 100644
--- a/vit_pytorch/ats_vit.py
+++ b/vit_pytorch/ats_vit.py
@@ -139,6 +139,8 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
 
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
         self.output_num_tokens = output_num_tokens
@@ -163,6 +165,7 @@ class Attention(nn.Module):
             dots = dots.masked_fill(~dots_mask, mask_value)
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         sampled_token_ids = None
 
diff --git a/vit_pytorch/cait.py b/vit_pytorch/cait.py
index 0572945..5968c6c 100644
--- a/vit_pytorch/cait.py
+++ b/vit_pytorch/cait.py
@@ -76,6 +76,7 @@ class Attention(nn.Module):
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
 
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
 
         self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
         self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))
@@ -96,7 +97,10 @@ class Attention(nn.Module):
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
         dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax
+
         attn = self.attend(dots)
+        attn = self.dropout(attn)
+
         attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
diff --git a/vit_pytorch/cross_vit.py b/vit_pytorch/cross_vit.py
index 881f2ac..4bb637f 100644
--- a/vit_pytorch/cross_vit.py
+++ b/vit_pytorch/cross_vit.py
@@ -48,6 +48,8 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
 
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
+
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
 
@@ -69,6 +71,7 @@ class Attention(nn.Module):
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
diff --git a/vit_pytorch/crossformer.py b/vit_pytorch/crossformer.py
index bc7c78a..65ec3cb 100644
--- a/vit_pytorch/crossformer.py
+++ b/vit_pytorch/crossformer.py
@@ -95,6 +95,9 @@ class Attention(nn.Module):
         self.window_size = window_size
 
         self.norm = LayerNorm(dim)
+
+        self.dropout = nn.Dropout(dropout)
+
         self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
         self.to_out = nn.Conv2d(inner_dim, dim, 1)
 
@@ -151,6 +154,7 @@ class Attention(nn.Module):
         # attend
 
         attn = sim.softmax(dim = -1)
+        attn = self.dropout(attn)
 
         # merge heads
 
diff --git a/vit_pytorch/cvt.py b/vit_pytorch/cvt.py
index 62406ec..add9ce9 100644
--- a/vit_pytorch/cvt.py
+++ b/vit_pytorch/cvt.py
@@ -76,6 +76,7 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
 
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
 
         self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
         self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)
@@ -94,6 +95,7 @@ class Attention(nn.Module):
         dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         out = einsum('b i j, b j d -> b i d', attn, v)
         out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
diff --git a/vit_pytorch/deepvit.py b/vit_pytorch/deepvit.py
index bf9d228..787034a 100644
--- a/vit_pytorch/deepvit.py
+++ b/vit_pytorch/deepvit.py
@@ -42,6 +42,8 @@ class Attention(nn.Module):
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
+        self.dropout = nn.Dropout(dropout)
+
         self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
 
         self.reattn_norm = nn.Sequential(
@@ -64,6 +66,7 @@ class Attention(nn.Module):
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
         attn = dots.softmax(dim=-1)
+        attn = self.dropout(attn)
 
         # re-attention
 
diff --git a/vit_pytorch/levit.py b/vit_pytorch/levit.py
index a823231..ffb3efa 100644
--- a/vit_pytorch/levit.py
+++ b/vit_pytorch/levit.py
@@ -52,6 +52,7 @@ class Attention(nn.Module):
         self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))
 
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
 
         out_batch_norm = nn.BatchNorm2d(dim_out)
         nn.init.zeros_(out_batch_norm.weight)
@@ -100,6 +101,7 @@ class Attention(nn.Module):
         dots = self.apply_pos_bias(dots)
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
diff --git a/vit_pytorch/local_vit.py b/vit_pytorch/local_vit.py
index 8a08cff..bf9716b 100644
--- a/vit_pytorch/local_vit.py
+++ b/vit_pytorch/local_vit.py
@@ -78,6 +78,7 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
 
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
@@ -93,6 +94,7 @@ class Attention(nn.Module):
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
diff --git a/vit_pytorch/mobile_vit.py b/vit_pytorch/mobile_vit.py
index 34b933e..c1a951f 100644
--- a/vit_pytorch/mobile_vit.py
+++ b/vit_pytorch/mobile_vit.py
@@ -54,6 +54,8 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
 
         self.attend = nn.Softmax(dim=-1)
+        self.dropout = nn.Dropout(dropout)
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
 
         self.to_out = nn.Sequential(
@@ -67,7 +69,10 @@ class Attention(nn.Module):
             t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
         attn = self.attend(dots)
+        attn = self.dropout(attn)
+
         out = torch.matmul(attn, v)
         out = rearrange(out, 'b p h n d -> b p n (h d)')
         return self.to_out(out)
diff --git a/vit_pytorch/nest.py b/vit_pytorch/nest.py
index 77edbec..b36da48 100644
--- a/vit_pytorch/nest.py
+++ b/vit_pytorch/nest.py
@@ -55,6 +55,7 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
 
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
 
         self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
         self.to_out = nn.Sequential(
@@ -71,6 +72,7 @@ class Attention(nn.Module):
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
diff --git a/vit_pytorch/parallel_vit.py b/vit_pytorch/parallel_vit.py
index 62c574b..bd736d2 100644
--- a/vit_pytorch/parallel_vit.py
+++ b/vit_pytorch/parallel_vit.py
@@ -50,6 +50,8 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
 
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
         self.to_out = nn.Sequential(
@@ -64,6 +66,7 @@ class Attention(nn.Module):
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         out = torch.matmul(attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
diff --git a/vit_pytorch/pit.py b/vit_pytorch/pit.py
index 35d123d..7ed257a 100644
--- a/vit_pytorch/pit.py
+++ b/vit_pytorch/pit.py
@@ -48,6 +48,7 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
 
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
@@ -63,6 +64,7 @@ class Attention(nn.Module):
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
diff --git a/vit_pytorch/regionvit.py b/vit_pytorch/regionvit.py
index 045468b..bdb6095 100644
--- a/vit_pytorch/regionvit.py
+++ b/vit_pytorch/regionvit.py
@@ -61,8 +61,13 @@ class Attention(nn.Module):
         inner_dim = dim_head * heads
 
         self.norm = nn.LayerNorm(dim)
+        self.dropout = nn.Dropout(dropout)
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        )
 
     def forward(self, x, rel_pos_bias = None):
         h = self.heads
@@ -86,6 +91,7 @@ class Attention(nn.Module):
             sim = sim + rel_pos_bias
 
         attn = sim.softmax(dim = -1)
+        attn = self.dropout(attn)
 
         # merge heads
 
diff --git a/vit_pytorch/rvt.py b/vit_pytorch/rvt.py
index 18b7a9b..5e95442 100644
--- a/vit_pytorch/rvt.py
+++ b/vit_pytorch/rvt.py
@@ -104,6 +104,7 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
 
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
 
         self.use_ds_conv = use_ds_conv
 
@@ -148,6 +149,7 @@ class Attention(nn.Module):
         dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         out = einsum('b i j, b j d -> b i d', attn, v)
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
diff --git a/vit_pytorch/scalable_vit.py b/vit_pytorch/scalable_vit.py
index 3dbb1be..0df4e06 100644
--- a/vit_pytorch/scalable_vit.py
+++ b/vit_pytorch/scalable_vit.py
@@ -90,6 +90,7 @@ class ScalableSelfAttention(nn.Module):
         self.heads = heads
         self.scale = dim_key ** -0.5
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
 
         self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
         self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
@@ -116,6 +117,7 @@ class ScalableSelfAttention(nn.Module):
         # attention
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         # aggregate values
 
@@ -141,6 +143,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
         self.scale = dim_key ** -0.5
         self.window_size = window_size
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
 
         self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)
 
@@ -176,6 +179,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
         # attention
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         # aggregate values
 
diff --git a/vit_pytorch/twins_svt.py b/vit_pytorch/twins_svt.py
index ec27cc2..8a548da 100644
--- a/vit_pytorch/twins_svt.py
+++ b/vit_pytorch/twins_svt.py
@@ -130,6 +130,8 @@ class GlobalAttention(nn.Module):
         self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
         self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)
+        self.dropout = nn.Dropout(dropout)
+
         self.to_out = nn.Sequential(
             nn.Conv2d(inner_dim, dim, 1),
             nn.Dropout(dropout)
         )
@@ -145,6 +147,7 @@ class GlobalAttention(nn.Module):
         dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
 
         attn = dots.softmax(dim = -1)
+        attn = self.dropout(attn)
 
         out = einsum('b i j, b j d -> b i d', attn, v)
         out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
diff --git a/vit_pytorch/vit.py b/vit_pytorch/vit.py
index 92e2972..8dc01a2 100644
--- a/vit_pytorch/vit.py
+++ b/vit_pytorch/vit.py
@@ -42,6 +42,8 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
 
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
         self.to_out = nn.Sequential(
@@ -56,6 +58,7 @@ class Attention(nn.Module):
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         out = torch.matmul(attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
diff --git a/vit_pytorch/vit_for_small_dataset.py b/vit_pytorch/vit_for_small_dataset.py
index 0e223ce..4884f22 100644
--- a/vit_pytorch/vit_for_small_dataset.py
+++ b/vit_pytorch/vit_for_small_dataset.py
@@ -42,6 +42,8 @@ class LSA(nn.Module):
         self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))
 
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
         self.to_out = nn.Sequential(
@@ -60,6 +62,7 @@ class LSA(nn.Module):
         dots = dots.masked_fill(mask, mask_value)
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         out = torch.matmul(attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
diff --git a/vit_pytorch/vit_with_patch_merger.py b/vit_pytorch/vit_with_patch_merger.py
index 3106bb3..5690ea8 100644
--- a/vit_pytorch/vit_with_patch_merger.py
+++ b/vit_pytorch/vit_with_patch_merger.py
@@ -63,6 +63,8 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
 
         self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
         self.to_out = nn.Sequential(
@@ -77,6 +79,7 @@ class Attention(nn.Module):
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
         attn = self.attend(dots)
+        attn = self.dropout(attn)
 
         out = torch.matmul(attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
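
Note: every hunk above applies the same fix. The minimal sketch below (assuming torch and einops are installed; it is modeled on the post-patch Attention in vit.py, not copied verbatim from any single file) shows where the new post-attention dropout sits: on the attention matrix immediately after the softmax, alongside the dropout that already followed the output projection. With the default dropout = 0. the added layer is an identity, so existing configurations behave exactly as before.

import torch
from torch import nn
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)   # post-attention dropout added by this patch

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)              # pre-existing dropout on the output projection
        )

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)            # the line this patch inserts after every softmax

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)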