make value residual learned

This commit is contained in:
lucidrains
2024-11-24 08:21:28 -08:00
parent 24196a3e8a
commit 56373c0cbd
2 changed files with 14 additions and 6 deletions

View File

@@ -6,10 +6,10 @@ with open('README.md') as f:
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.8.8',
version = '1.8.9',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
long_description = long_description,
long_description_content_type = 'text/markdown',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',

View File

@@ -38,7 +38,7 @@ def FeedForward(dim, hidden_dim):
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):
def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
@@ -50,6 +50,12 @@ class Attention(Module):
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
self.to_residual_mix = nn.Sequential(
nn.Linear(dim, heads),
nn.Sigmoid(),
Rearrange('b n h -> b h n 1')
) if learned_value_residual_mix else (lambda _: 0.5)
def forward(self, x, value_residual = None):
x = self.norm(x)
@@ -57,7 +63,8 @@ class Attention(Module):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
if exists(value_residual):
v = 0.5 * (v + value_residual)
mix = self.to_residual_mix(x)
v = v * mix + value_residual * (1. - mix)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
@@ -73,9 +80,10 @@ class Transformer(Module):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
for i in range(depth):
is_first = i == 0
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):