diff --git a/setup.py b/setup.py
index f6da4c3..4a80dbc 100644
--- a/setup.py
+++ b/setup.py
@@ -6,10 +6,10 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.8.8',
+  version = '1.8.9',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
-  long_description=long_description,
+  long_description = long_description,
   long_description_content_type = 'text/markdown',
   author = 'Phil Wang',
   author_email = 'lucidrains@gmail.com',
diff --git a/vit_pytorch/simple_vit_with_value_residual.py b/vit_pytorch/simple_vit_with_value_residual.py
index b87d1f5..2713471 100644
--- a/vit_pytorch/simple_vit_with_value_residual.py
+++ b/vit_pytorch/simple_vit_with_value_residual.py
@@ -38,7 +38,7 @@ def FeedForward(dim, hidden_dim):
     )

 class Attention(Module):
-    def __init__(self, dim, heads = 8, dim_head = 64):
+    def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -50,6 +50,12 @@ class Attention(Module):
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)

+        self.to_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            nn.Sigmoid(),
+            Rearrange('b n h -> b h n 1')
+        ) if learned_value_residual_mix else (lambda _: 0.5)
+
     def forward(self, x, value_residual = None):
         x = self.norm(x)

@@ -57,7 +63,8 @@ class Attention(Module):
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

         if exists(value_residual):
-            v = 0.5 * (v + value_residual)
+            mix = self.to_residual_mix(x)
+            v = v * mix + value_residual * (1. - mix)

         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

@@ -73,9 +80,10 @@ class Transformer(Module):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
         self.layers = ModuleList([])
-        for _ in range(depth):
+        for i in range(depth):
+            is_first = i == 0
             self.layers.append(ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head),
+                Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first),
                 FeedForward(dim, mlp_dim)
             ]))
     def forward(self, x):
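
For context (not part of the diff): the change replaces the fixed 0.5 average between each layer's values and the first layer's values with a per-head, per-token mix predicted from the layer input, so every attention layer after the first learns how much value residual to keep. The sketch below is a minimal, self-contained illustration of that mixing step; the batch/sequence/head sizes and the standalone to_residual_mix module are illustrative assumptions, while the names, shapes, and mixing math mirror the diff.

# Illustrative sketch (assumed shapes) of the learned value-residual mix added above.
import torch
from torch import nn
from einops.layers.torch import Rearrange

batch, seq, dim, heads, dim_head = 2, 16, 64, 8, 32

# per-head, per-token gate in (0, 1), as in the diff's self.to_residual_mix
to_residual_mix = nn.Sequential(
    nn.Linear(dim, heads),
    nn.Sigmoid(),
    Rearrange('b n h -> b h n 1')
)

x = torch.randn(batch, seq, dim)                # tokens entering this attention layer
v = torch.randn(batch, heads, seq, dim_head)    # this layer's values
value_residual = torch.randn_like(v)            # values cached from the first layer

mix = to_residual_mix(x)                        # (batch, heads, seq, 1)
v = v * mix + value_residual * (1. - mix)       # convex combination, replacing 0.5 * (v + value_residual)

print(v.shape)                                  # torch.Size([2, 8, 16, 32])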