Mirror of https://github.com/lucidrains/vit-pytorch.git (synced 2025-12-30 08:02:29 +00:00)

Commit: make value residual learned
setup.py (4 changed lines)
@@ -6,10 +6,10 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.8.8',
+  version = '1.8.9',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
-  long_description=long_description,
+  long_description = long_description,
   long_description_content_type = 'text/markdown',
   author = 'Phil Wang',
   author_email = 'lucidrains@gmail.com',

@@ -38,7 +38,7 @@ def FeedForward(dim, hidden_dim):
     )

 class Attention(Module):
-    def __init__(self, dim, heads = 8, dim_head = 64):
+    def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -50,6 +50,12 @@ class Attention(Module):
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)

+        self.to_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            nn.Sigmoid(),
+            Rearrange('b n h -> b h n 1')
+        ) if learned_value_residual_mix else (lambda _: 0.5)
+
     def forward(self, x, value_residual = None):
         x = self.norm(x)

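The new to_residual_mix branch produces one sigmoid gate per attention head, per token; when learned_value_residual_mix is off, the branch is just a constant function returning 0.5, preserving the old fixed averaging. A minimal shape check of that branch in isolation (standalone sketch, hypothetical sizes):

import torch
from torch import nn
from einops.layers.torch import Rearrange

b, n, dim, heads = 2, 16, 256, 8  # hypothetical batch, tokens, model width, heads

to_residual_mix = nn.Sequential(
    nn.Linear(dim, heads),          # one gate logit per head, per token
    nn.Sigmoid(),                   # squash to (0, 1)
    Rearrange('b n h -> b h n 1')   # trailing 1 broadcasts over dim_head
)

mix = to_residual_mix(torch.randn(b, n, dim))
print(mix.shape)  # torch.Size([2, 8, 16, 1]), aligned with v: (b, heads, n, dim_head)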
@@ -57,7 +63,8 @@
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

         if exists(value_residual):
-            v = 0.5 * (v + value_residual)
+            mix = self.to_residual_mix(x)
+            v = v * mix + value_residual * (1. - mix)

         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

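With mix fixed at 0.5 the new line reduces exactly to the removed one, since v * 0.5 + value_residual * 0.5 == 0.5 * (v + value_residual); the learned gate generalizes the fixed average to a per-head, per-token convex combination. A quick sketch with dummy tensors (hypothetical shapes):

import torch

b, h, n, d = 2, 8, 16, 64                # hypothetical shapes
v = torch.randn(b, h, n, d)              # values of the current layer
value_residual = torch.randn_like(v)     # values cached from the first layer
mix = torch.rand(b, h, n, 1)             # stand-in for the learned gate

mixed = v * mix + value_residual * (1. - mix)   # new convex combination

# with the gate pinned at 0.5, this recovers the previous fixed average
half = torch.full_like(mix, 0.5)
assert torch.allclose(v * half + value_residual * (1. - half),
                      0.5 * (v + value_residual))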
@@ -73,9 +80,10 @@ class Transformer(Module):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
         self.layers = ModuleList([])
-        for _ in range(depth):
+        for i in range(depth):
+            is_first = i == 0
             self.layers.append(ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head),
+                Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first),
                 FeedForward(dim, mlp_dim)
             ]))

     def forward(self, x):
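The diff cuts off at the signature of Transformer.forward, but the wiring it implies is that the first block (learned_value_residual_mix = False) computes its values normally, those values are cached, and every later block mixes its own values with the cached ones through its learned gate. A self-contained toy sketch of that loop follows; the single-head gate and the convention that attention returns its values alongside its output are assumptions inferred from the value_residual keyword above, not taken from the diff:

import torch
from torch import nn

class ToyAttention(nn.Module):
    # stripped-down stand-in: linear values only, attention math omitted
    def __init__(self, dim, learned_value_residual_mix = False):
        super().__init__()
        self.to_v = nn.Linear(dim, dim, bias = False)
        self.to_residual_mix = nn.Sequential(
            nn.Linear(dim, 1),
            nn.Sigmoid()
        ) if learned_value_residual_mix else (lambda _: 0.5)

    def forward(self, x, value_residual = None):
        v = self.to_v(x)
        if value_residual is not None:
            mix = self.to_residual_mix(x)
            v = v * mix + value_residual * (1. - mix)
        return v, v  # (output, values to cache downstream) - assumed signature

dim, depth = 32, 4
layers = nn.ModuleList([
    ToyAttention(dim, learned_value_residual_mix = i > 0)  # first layer unlearned, as in the diff
    for i in range(depth)
])

x = torch.randn(2, 16, dim)
value_residual = None
for attn in layers:
    out, values = attn(x, value_residual = value_residual)
    x = out + x
    if value_residual is None:
        value_residual = values  # cache the first block's values for all later blocks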