# mirror of https://github.com/lucidrains/vit-pytorch.git
# synced 2025-12-30 08:02:29 +00:00
# 279 lines · 8.3 KiB · Python
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
from torch.nn import Module, ModuleList
|
|
|
|
from einops import einsum, rearrange, repeat, reduce
|
|
from einops.layers.torch import Rearrange
|
|
|
|
# helpers
|
|
|
|
def exists(val):
    """Return True when `val` is anything other than None (falsy values count)."""
    is_missing = val is None
    return not is_missing
|
|
|
|
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`.

    Only None triggers the fallback; other falsy values (0, "", []) are kept.
    """
    if val is None:
        return d
    return val
|
|
|
|
def divisible_by(num, den):
    """Return whether `num` divides evenly by `den` (i.e. `num % den == 0`)."""
    remainder = num % den
    return remainder == 0
|
|
|
|
# simple vit sinusoidal pos emb
|
|
|
|
def posemb_sincos_2d(t, temperature = 10000, dtype = torch.float32):
    """Build fixed 2D sine/cosine positional embeddings for a spatial grid.

    Args:
        t: 4D tensor unpacked as (batch, h, w, d). Only its shape and device
            are read; its values are ignored.
        temperature: base of the geometric frequency progression; the
            frequencies run from temperature ** 0 down to temperature ** -1.
        dtype: dtype of the returned tensor. Defaults to float32, which
            matches the previous unconditional `.float()`.

    Returns:
        Tensor of shape (h * w, d) on `t.device`. Row `y * w + x` holds the
        embedding of grid cell (y, x), laid out as
        [sin(x * f), cos(x * f), sin(y * f), cos(y * f)], each chunk d // 4
        wide.

    Raises:
        AssertionError: if d is not a multiple of 4.
    """
    h, w, d, device = *t.shape[1:], t.device
    assert (d % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')

    # Normalize the frequency index to [0, 1]. With d == 4 there is a single
    # frequency and `num_freqs - 1` would be 0, giving 0 / 0 = NaN for the
    # whole embedding, so the denominator is clamped to 1.
    num_freqs = d // 4
    omega = torch.arange(num_freqs, device = device) / max(num_freqs - 1, 1)
    omega = temperature ** -omega

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pos = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)

    return pos.to(dtype)
|
|
|
|
# bias-less layernorm with unit offset trick (discovered by Ohad Rubin)
|
|
|
|
class LayerNorm(Module):
    """Bias-free LayerNorm whose learned scale is stored as an offset from 1.

    `gamma` starts at zero, so the effective per-feature scale `gamma + 1`
    starts at one. Presumably this is so that weight decay on `gamma` pulls the
    scale toward 1 rather than toward 0.
    """

    def __init__(self, dim):
        super().__init__()
        # plain normalization; the only learned state is `gamma` below
        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
        self.gamma = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        """Normalize over the last dimension and apply the scale `gamma + 1`."""
        scale = self.gamma + 1
        return self.ln(x) * scale
|
|
|
|
# mlp
|
|
|
|
def MLP(dim, factor = 4, dropout = 0.):
    """Build a pre-norm feedforward block.

    Layout: LayerNorm -> Linear(dim, hidden) -> GELU -> Dropout
    -> Linear(hidden, dim) -> Dropout, with hidden = int(dim * factor).
    Returned as an `nn.Sequential`, so the indices 0..5 are the prefixes of the
    parameter names in a state dict.
    """
    hidden_dim = int(dim * factor)

    layers = [
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
        nn.Dropout(dropout),
    ]
    return nn.Sequential(*layers)
|
|
|
|
# attention
|
|
|
|
class Attention(Module):
    """Multi-head attention that can self-attend, cross-attend, or reuse a
    similarity matrix computed elsewhere.

    Modes:
        cross_attend=False: queries, keys and values all come from `x`.
        cross_attend=True: keys/values come from a separately normalized
            `context`, which must then be passed to `forward`.
        reuse_attention=True: the module has no query/key projections (and
            no norm on `x`); the caller supplies `qk_sim` and only values are
            projected.
    """

    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        cross_attend = False,
        reuse_attention = False
    ):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.cross_attend = cross_attend
        self.reuse_attention = reuse_attention

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)

        # a reusing layer never projects `x`, so there is nothing to normalize
        self.norm = nn.Identity() if reuse_attention else LayerNorm(dim)
        self.norm_context = LayerNorm(dim) if cross_attend else nn.Identity()

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        if reuse_attention:
            self.to_q = None
            self.to_k = None
        else:
            self.to_q = nn.Linear(dim, inner_dim, bias = False)
            self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)

        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        context = None,
        return_qk_sim = False,
        qk_sim = None
    ):
        """Attend from `x` over `context` (or over `x` itself).

        Args:
            x: query-side tokens, split as 'b n (h d)' after projection.
            context: key/value tokens. Required exactly when the module was
                built with `cross_attend=True`.
            return_qk_sim: also return the (scaled, pre-softmax) similarity
                matrix of shape (b, heads, n_x, n_context).
            qk_sim: precomputed similarity. Required when
                `reuse_attention=True`; ignored (recomputed) otherwise.

        Returns:
            The attended tokens, or `(out, qk_sim)` if `return_qk_sim`.
        """
        x = self.norm(x)

        assert not (exists(context) ^ self.cross_attend)

        context = self.norm_context(context) if self.cross_attend else x

        values = self.split_heads(self.to_v(context))

        if self.reuse_attention:
            assert exists(qk_sim), 'qk sim matrix must be passed in for reusing previous attention'
        else:
            queries = self.split_heads(self.to_q(x)) * self.scale
            keys = self.split_heads(self.to_k(context))
            qk_sim = einsum(queries, keys, 'b h i d, b h j d -> b h i j')

        weights = self.dropout(self.attend(qk_sim))
        out = self.to_out(einsum(weights, values, 'b h i j, b h j d -> b h i d'))

        return (out, qk_sim) if return_qk_sim else out
|
|
|
|
# LookViT
|
|
|
|
class LookViT(Module):
    """Image classifier with a coarse token stream and a high-resolution one.

    Coarse stream: (image_size // patch_size) ** 2 tokens. These are produced
    by bilinearly resizing the high-resolution token grid. They run lookup
    cross attention into the high-resolution tokens, then self attention and
    an MLP.

    High-resolution stream: (image_size // highres_patch_size) ** 2 tokens.
    They attend back to the coarse tokens by reusing the transposed similarity
    matrix of the lookup, then go through a LayerNorm and an MLP.

    Logits come from the sum of the mean-pooled, normalized tokens of both
    streams.
    """

    def __init__(
        self,
        *,
        dim,
        image_size,
        num_classes,
        depth = 3,
        patch_size = 16,
        heads = 8,
        mlp_factor = 4,
        dim_head = 64,
        highres_patch_size = 12,
        highres_mlp_factor = 4,
        cross_attn_heads = 8,
        cross_attn_dim_head = 64,
        patch_conv_kernel_size = 7,
        dropout = 0.1,
        channels = 3
    ):
        super().__init__()
        assert divisible_by(image_size, highres_patch_size), 'image size must be divisible by the highres patch size'
        assert divisible_by(image_size, patch_size), 'image size must be divisible by the patch size'
        assert patch_size > highres_patch_size, 'patch size of the main vision transformer should be larger than the highres patch size (used for the `lookup`)'
        # an odd kernel with padding kernel_size // 2 keeps the patch grid size unchanged
        assert not divisible_by(patch_conv_kernel_size, 2), 'patch conv kernel size must be odd'

        self.dim = dim
        self.image_size = image_size
        self.patch_size = patch_size

        kernel_size = patch_conv_kernel_size
        patch_dim = (highres_patch_size * highres_patch_size) * channels

        # image -> channels-last grid of highres patch tokens, shape (b, h, w, dim)
        self.to_patches = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = highres_patch_size, p2 = highres_patch_size),
            nn.Conv2d(patch_dim, dim, kernel_size, padding = kernel_size // 2),
            Rearrange('b c h w -> b h w c'),
            LayerNorm(dim),
        )

        # absolute positions
        # NOTE(review): this learned table is never read in `forward`; the fixed
        # sincos embedding is used instead. It is kept so existing state dicts
        # and seeded initialization stay compatible. Be aware that an unused
        # parameter can make DistributedDataParallel error out unless
        # `find_unused_parameters=True`; consider removing it.

        num_patches = (image_size // highres_patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))

        # lookvit blocks: order must match the unpacking in `forward`

        layers = ModuleList([])

        for _ in range(depth):
            layers.append(ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
                MLP(dim = dim, factor = mlp_factor, dropout = dropout),
                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True),
                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True, reuse_attention = True),
                LayerNorm(dim),
                MLP(dim = dim, factor = highres_mlp_factor, dropout = dropout)
            ]))

        self.layers = layers

        self.norm = LayerNorm(dim)
        self.highres_norm = LayerNorm(dim)

        self.to_logits = nn.Linear(dim, num_classes, bias = False)

    def forward(self, img):
        """Classify a batch of images.

        Args:
            img: image batch, (batch, channels, height, width) per the
                patchifying Rearrange. Height and width must both equal
                `image_size`.

        Returns:
            Logits of shape (batch, num_classes).
        """
        assert img.shape[-2:] == (self.image_size, self.image_size)

        # to patch tokens and positions

        highres_tokens = self.to_patches(img)  # (b, h, w, d)
        size = highres_tokens.shape[1]  # grid height (equals width: the image is square)

        # The sincos table is float32. Cast it to the token dtype so half or
        # bfloat16 models are not promoted to float32, which would then clash
        # with the low-precision Linear weights downstream.
        pos_emb = posemb_sincos_2d(highres_tokens).to(highres_tokens.dtype)
        highres_tokens = highres_tokens + rearrange(pos_emb, '(h w) d -> h w d', h = size)

        # coarse tokens: bilinearly resize the highres grid to the main patch grid
        tokens = F.interpolate(
            rearrange(highres_tokens, 'b h w d -> b d h w'),
            img.shape[-1] // self.patch_size,
            mode = 'bilinear'
        )

        tokens = rearrange(tokens, 'b c h w -> b (h w) c')
        highres_tokens = rearrange(highres_tokens, 'b h w c -> b (h w) c')

        # attention and feedforwards

        for attn, mlp, lookup_cross_attn, highres_attn, highres_norm, highres_mlp in self.layers:

            # main tokens cross attend (lookup) on the high res tokens; the
            # similarity matrix is returned so the reverse direction can reuse it

            lookup_out, qk_sim = lookup_cross_attn(tokens, highres_tokens, return_qk_sim = True)
            tokens = lookup_out + tokens

            tokens = attn(tokens) + tokens
            tokens = mlp(tokens) + tokens

            # attention-reuse: transpose (main x highres) -> (highres x main) so
            # the highres tokens attend to the main tokens without new q/k projections

            qk_sim = rearrange(qk_sim, 'b h i j -> b h j i')

            highres_tokens = highres_attn(highres_tokens, tokens, qk_sim = qk_sim) + highres_tokens
            highres_tokens = highres_norm(highres_tokens)

            highres_tokens = highres_mlp(highres_tokens) + highres_tokens

        # to logits

        tokens = self.norm(tokens)
        highres_tokens = self.highres_norm(highres_tokens)

        tokens = reduce(tokens, 'b n d -> b d', 'mean')
        highres_tokens = reduce(highres_tokens, 'b n d -> b d', 'mean')

        return self.to_logits(tokens + highres_tokens)
|
|
|
|
# main
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build a small LookViT and check the logits shape. Run on
    # CUDA when available, otherwise fall back to CPU instead of crashing on
    # an unconditional `.cuda()`.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    v = LookViT(
        image_size = 256,
        num_classes = 1000,
        dim = 512,
        depth = 2,
        heads = 8,
        dim_head = 64,
        patch_size = 32,
        highres_patch_size = 8,
        highres_mlp_factor = 2,
        cross_attn_heads = 8,
        cross_attn_dim_head = 64,
        dropout = 0.1
    ).to(device)

    img = torch.randn(2, 3, 256, 256, device = device)
    pred = v(img)

    assert pred.shape == (2, 1000)
|