mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
when cross attending in look vit, make sure context tokens are normalized
This commit is contained in:
2
setup.py
2
setup.py
@@ -6,7 +6,7 @@ with open('README.md') as f:
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.7.2',
|
||||
version = '1.7.3',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description=long_description,
|
||||
|
||||
@@ -66,6 +66,7 @@ class Attention(Module):
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
cross_attend = False,
|
||||
reuse_attention = False
|
||||
):
|
||||
super().__init__()
|
||||
@@ -74,10 +75,13 @@ class Attention(Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
self.reuse_attention = reuse_attention
|
||||
self.cross_attend = cross_attend
|
||||
|
||||
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
|
||||
|
||||
self.norm = LayerNorm(dim) if not reuse_attention else nn.Identity()
|
||||
self.norm_context = LayerNorm(dim) if cross_attend else nn.Identity()
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@@ -99,7 +103,13 @@ class Attention(Module):
|
||||
attn = None
|
||||
):
|
||||
x = self.norm(x)
|
||||
context = default(context, x)
|
||||
|
||||
assert not (exists(context) ^ self.cross_attend)
|
||||
|
||||
if self.cross_attend:
|
||||
context = self.norm_context(context)
|
||||
else:
|
||||
context = x
|
||||
|
||||
v = self.to_v(context)
|
||||
v = self.split_heads(v)
|
||||
@@ -179,8 +189,8 @@ class LookViT(Module):
|
||||
layers.append(ModuleList([
|
||||
Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
|
||||
MLP(dim = dim, factor = mlp_factor, dropout = dropout),
|
||||
Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout),
|
||||
Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, reuse_attention = True),
|
||||
Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True),
|
||||
Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True, reuse_attention = True),
|
||||
LayerNorm(dim),
|
||||
MLP(dim = dim, factor = highres_mlp_factor, dropout = dropout)
|
||||
]))
|
||||
|
||||
Reference in New Issue
Block a user