# mirror of https://github.com/lucidrains/vit-pytorch.git
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
import torch.nn.functional as F

# helpers

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# cross embed layer
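# cross-scale embedding: several parallel convolutions with different kernel sizes but
# the same stride embed the input at multiple scales, and their outputs are concatenated
# along the channel dimension. the smallest kernel receives half the output channels,
# each larger kernel half of that, and the largest takes the remainder.
#
# illustrative example (not in the original file): CrossEmbedLayer(3, 64,
# kernel_sizes = (4, 8, 16, 32), stride = 4) on a (1, 3, 224, 224) image returns a
# (1, 64, 56, 56) feature map, with the 64 channels split 32 / 16 / 8 / 8 across the
# four kernel sizes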

class CrossEmbedLayer(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_sizes,
        stride = 2
    ):
        super().__init__()
        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # calculate the dimension at each scale
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    def forward(self, x):
        fmaps = tuple(map(lambda conv: conv(x), self.convs))
        return torch.cat(fmaps, dim = 1)

# dynamic positional bias
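# a small MLP that maps a 2-d relative position offset to a single scalar bias. the
# biases are computed in Attention.forward for every offset within the window and added
# to the attention logits, shared across heads, so no fixed-size relative position
# table is needed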

def DynamicPositionBias(dim):
    return nn.Sequential(
        nn.Linear(2, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, 1),
        Rearrange('... () -> ...')
    )

# transformer classes
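# LayerNorm below normalizes (batch, channels, height, width) feature maps over the
# channel dimension, and FeedForward is a pre-norm MLP built from 1x1 convolutions,
# i.e. it is applied independently at every spatial location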

class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

def FeedForward(dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        LayerNorm(dim),
        nn.Conv2d(dim, dim * mult, 1),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Conv2d(dim * mult, dim, 1)
    )
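
# long-short distance attention. 'short' attends within non-overlapping
# window_size x window_size windows of adjacent positions, while 'long' attends within
# groups of positions sampled at a regular interval across the whole feature map.
# both add a dynamic positional bias to the attention logits.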
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        attn_type,
        window_size,
        dim_head = 32,
        dropout = 0.
    ):
        super().__init__()
        assert attn_type in {'short', 'long'}, 'attention type must be one of short or long'
        heads = dim // dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.attn_type = attn_type
        self.window_size = window_size

        self.norm = LayerNorm(dim)

        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

        # positions

        self.dpb = DynamicPositionBias(dim // 4)

        # calculate and store indices for retrieving bias
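        # rel_pos_indices maps every (query, key) position pair within a window to an
        # entry of the bias vector produced by self.dpb in forward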

        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        rel_pos = grid[:, None] - grid[None, :]
        rel_pos += window_size - 1
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x):
        *_, height, width, heads, wsz, device = *x.shape, self.heads, self.window_size, x.device

        # prenorm

        x = self.norm(x)

        # rearrange for short or long distance attention

        if self.attn_type == 'short':
            x = rearrange(x, 'b d (h s1) (w s2) -> (b h w) d s1 s2', s1 = wsz, s2 = wsz)
        elif self.attn_type == 'long':
            x = rearrange(x, 'b d (l1 h) (l2 w) -> (b h w) d l1 l2', l1 = wsz, l2 = wsz)
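
        # 'short' groups each wsz x wsz patch of adjacent positions into its own window,
        # while 'long' groups positions that are height // wsz rows and width // wsz
        # columns apart, i.e. a strided sampling of the feature map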

        # queries / keys / values

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        # split heads

        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), (q, k, v))
        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add dynamic positional bias

        pos = torch.arange(-wsz, wsz + 1, device = device)
        rel_pos = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
        biases = self.dpb(rel_pos.float())
        rel_pos_bias = biases[self.rel_pos_indices]

        sim = sim + rel_pos_bias

        # attend

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # merge heads

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = wsz, y = wsz)
        out = self.to_out(out)

        # rearrange back for long or short distance attention

        if self.attn_type == 'short':
            out = rearrange(out, '(b h w) d s1 s2 -> b d (h s1) (w s2)', h = height // wsz, w = width // wsz)
        elif self.attn_type == 'long':
            out = rearrange(out, '(b h w) d l1 l2 -> b d (l1 h) (l2 w)', h = height // wsz, w = width // wsz)

        return out
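
# a transformer stage: each layer applies short distance (local window) attention, a
# feedforward, long distance attention, and another feedforward, each with a residual
# connection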
class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        local_window_size,
        global_window_size,
        depth = 4,
        dim_head = 32,
        attn_dropout = 0.,
        ff_dropout = 0.,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, attn_type = 'short', window_size = local_window_size, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout),
                Attention(dim, attn_type = 'long', window_size = global_window_size, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout)
            ]))

    def forward(self, x):
        for short_attn, short_ff, long_attn, long_ff in self.layers:
            x = short_attn(x) + x
            x = short_ff(x) + x
            x = long_attn(x) + x
            x = long_ff(x) + x

        return x

# classes
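# CrossFormer: four stages, each a CrossEmbedLayer that downsamples and widens the
# feature map followed by a Transformer stage, ending in global average pooling and a
# linear classification head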

class CrossFormer(nn.Module):
    def __init__(
        self,
        *,
        dim = (64, 128, 256, 512),
        depth = (2, 2, 8, 2),
        global_window_size = (8, 4, 2, 1),
        local_window_size = 7,
        cross_embed_kernel_sizes = ((4, 8, 16, 32), (2, 4), (2, 4), (2, 4)),
        cross_embed_strides = (4, 2, 2, 2),
        num_classes = 1000,
        attn_dropout = 0.,
        ff_dropout = 0.,
        channels = 3
    ):
        super().__init__()

        dim = cast_tuple(dim, 4)
        depth = cast_tuple(depth, 4)
        global_window_size = cast_tuple(global_window_size, 4)
        local_window_size = cast_tuple(local_window_size, 4)
        cross_embed_kernel_sizes = cast_tuple(cross_embed_kernel_sizes, 4)
        cross_embed_strides = cast_tuple(cross_embed_strides, 4)

        assert len(dim) == 4
        assert len(depth) == 4
        assert len(global_window_size) == 4
        assert len(local_window_size) == 4
        assert len(cross_embed_kernel_sizes) == 4
        assert len(cross_embed_strides) == 4

        # dimensions

        last_dim = dim[-1]
        dims = [channels, *dim]
        dim_in_and_out = tuple(zip(dims[:-1], dims[1:]))
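        # e.g. with the default dim and channels = 3, dim_in_and_out is
        # ((3, 64), (64, 128), (128, 256), (256, 512))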

        # layers

        self.layers = nn.ModuleList([])

        for (dim_in, dim_out), layers, global_wsz, local_wsz, cel_kernel_sizes, cel_stride in zip(dim_in_and_out, depth, global_window_size, local_window_size, cross_embed_kernel_sizes, cross_embed_strides):
            self.layers.append(nn.ModuleList([
                CrossEmbedLayer(dim_in, dim_out, cel_kernel_sizes, stride = cel_stride),
                Transformer(dim_out, local_window_size = local_wsz, global_window_size = global_wsz, depth = layers, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
            ]))

        # final logits

        self.to_logits = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(last_dim, num_classes)
        )

    def forward(self, x):
        for cel, transformer in self.layers:
            x = cel(x)
            x = transformer(x)

        return self.to_logits(x)
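
# minimal usage sketch (not part of the original file): the hyperparameters are just
# the defaults above, and the 224 x 224 input size is an assumption for illustration

if __name__ == '__main__':
    model = CrossFormer(
        num_classes = 1000,
        dim = (64, 128, 256, 512),
        depth = (2, 2, 8, 2),
        global_window_size = (8, 4, 2, 1),
        local_window_size = 7
    )

    img = torch.randn(1, 3, 224, 224)

    logits = model(img)
    print(logits.shape)  # expected: torch.Size([1, 1000])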