Files
vit-pytorch/vit_pytorch/crossformer.py
2021-11-22 17:10:53 -08:00

261 lines
7.8 KiB
Python

import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
import torch.nn.functional as F
# helpers
def exists(val):
    """Report whether ``val`` holds a value, i.e. is anything other than ``None``."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return ``val`` unless it is ``None`` (as judged by ``exists``), otherwise ``d``."""
    if exists(val):
        return val
    return d
def cast_tuple(val, length = 1):
    """Normalise ``val`` to a tuple.

    A tuple is returned unchanged (its length is NOT checked against
    ``length``). Any other value, lists included, is repeated ``length``
    times inside a new tuple.
    """
    if isinstance(val, tuple):
        return val
    return tuple(val for _ in range(length))
def divisible_by(val, d):
    """Return True when ``val`` is an exact multiple of ``d``."""
    remainder = val % d
    return remainder == 0
# cross embed layer
class CrossEmbedLayer(nn.Module):
    """Multi-scale strided convolutional embedding.

    The input is convolved in parallel by one ``nn.Conv2d`` per kernel size,
    all sharing the same ``stride``. The feature maps are concatenated along
    dim 1.

    Kernel sizes are sorted ascending, and smaller kernels receive more output
    channels. The i-th conv (0-based, all but the last) gets
    ``int(dim_out / 2 ** (i + 1))`` channels. The last conv gets the remainder,
    so the channels always sum to ``dim_out``.

    Args:
        dim_in: input channels of every conv.
        dim_out: total output channels after concatenation.
        kernel_sizes: iterable of kernel sizes, one conv per entry.
        stride: stride shared by all convs.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_sizes,
        stride = 2
    ):
        super().__init__()
        ordered_kernels = sorted(kernel_sizes)
        scale_count = len(ordered_kernels)

        # halve the channel budget per scale; the final scale absorbs the remainder
        channels_per_scale = []
        for level in range(1, scale_count):
            channels_per_scale.append(int(dim_out / (2 ** level)))
        channels_per_scale.append(dim_out - sum(channels_per_scale))

        # padding = (kernel - stride) // 2: when kernel - stride is even, every branch
        # yields the same spatial size, which torch.cat in forward requires
        self.convs = nn.ModuleList([
            nn.Conv2d(dim_in, channels, kernel, stride = stride, padding = (kernel - stride) // 2)
            for kernel, channels in zip(ordered_kernels, channels_per_scale)
        ])

    def forward(self, x):
        # every branch sees the same input; results are stacked along dim 1
        return torch.cat([conv(x) for conv in self.convs], dim = 1)
# dynamic positional bias
def DynamicPositionBias(dim):
    """Build the MLP that maps 2-d relative offsets to a scalar bias.

    Layout: Linear(2 -> dim), LayerNorm, ReLU, then two repeats of
    [Linear(dim -> dim), LayerNorm, ReLU], then Linear(dim -> 1). The trailing
    size-1 feature axis is squeezed away by the einops ``Rearrange``, so input
    ``(..., 2)`` becomes output ``(...)``.
    """
    layers = [nn.Linear(2, dim), nn.LayerNorm(dim), nn.ReLU()]
    for _ in range(2):
        layers.extend([nn.Linear(dim, dim), nn.LayerNorm(dim), nn.ReLU()])
    layers.append(nn.Linear(dim, 1))
    # drop the singleton output feature so the bias can broadcast against attention logits
    layers.append(Rearrange('... () -> ...'))
    return nn.Sequential(*layers)
# transformer classes
class LayerNorm(nn.Module):
    """Layer normalisation over dim 1 (the channel axis of an NCHW tensor).

    Each position is normalised by the mean and biased variance of its values
    along dim 1, then scaled by ``g`` and shifted by ``b``. Both parameters have
    shape ``(1, dim, 1, 1)``, so the input is expected to be 4-d with ``dim``
    entries on axis 1.

    Args:
        dim: size of axis 1 of the input.
        eps: added to the variance before the square root, for numerical stability.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        # eps goes inside the square root, as in nn.LayerNorm. sqrt has an infinite
        # derivative at 0, so the previous `var.sqrt() + eps` produced NaN gradients
        # whenever the values along dim 1 were all equal (zero variance).
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
def FeedForward(dim, mult = 4, dropout = 0.):
    """Pre-normalised position-wise MLP built from 1x1 convolutions.

    Layout: LayerNorm -> Conv1x1(dim -> dim * mult) -> GELU -> Dropout ->
    Conv1x1(dim * mult -> dim).
    """
    hidden_dim = dim * mult
    norm = LayerNorm(dim)
    expand = nn.Conv2d(dim, hidden_dim, 1)
    project = nn.Conv2d(hidden_dim, dim, 1)
    return nn.Sequential(norm, expand, nn.GELU(), nn.Dropout(dropout), project)
class Attention(nn.Module):
    """Windowed multi-head self-attention used by CrossFormer.

    ``attn_type = 'short'`` attends within non-overlapping
    ``window_size x window_size`` blocks of adjacent positions.

    ``attn_type = 'long'`` groups positions that lie ``height // window_size``
    rows and ``width // window_size`` columns apart. Each group therefore spans
    the whole feature map.

    The input is normalised first (pre-norm); the caller adds the residual.
    A learned ``DynamicPositionBias`` MLP maps every pair of in-window 2-d
    offsets to a bias. That bias is added to the attention logits and is
    shared by all heads.

    Args:
        dim: number of input / output channels (axis 1).
        attn_type: ``'short'`` or ``'long'``.
        window_size: side length of each attention window / group.
        dim_head: channels per head; the head count is ``dim // dim_head``.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(
        self,
        dim,
        attn_type,
        window_size,
        dim_head = 32,
        dropout = 0.
    ):
        super().__init__()
        assert attn_type in {'short', 'long'}, 'attention type must be one of short or long'
        heads = dim // dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.attn_type = attn_type
        self.window_size = window_size

        self.dpb = DynamicPositionBias(dim // 4)
        self.norm = LayerNorm(dim)

        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

        # `dropout` used to be accepted but silently ignored; apply it to the attention weights
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x is (batch, dim, height, width); the rearranges below require height and
        # width to be divisible by window_size
        *_, height, width, heads, wsz, device = *x.shape, self.heads, self.window_size, x.device

        # prenorm
        x = self.norm(x)

        # fold the windows into the batch axis
        if self.attn_type == 'short':
            # contiguous wsz x wsz blocks
            x = rearrange(x, 'b d (h s1) (w s2) -> (b h w) d s1 s2', s1 = wsz, s2 = wsz)
        elif self.attn_type == 'long':
            # wsz x wsz samples spaced (height // wsz, width // wsz) apart
            x = rearrange(x, 'b d (l1 h) (l2 w) -> (b h w) d l1 l2', l1 = wsz, l2 = wsz)

        # queries / keys / values
        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        # split heads: (B, heads * dim_head, wsz, wsz) -> (B, heads, wsz * wsz, dim_head)
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), (q, k, v))

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # dynamic positional bias: row-major (i, j) coordinates of the wsz * wsz window
        # positions, matching the '(x y)' flattening above. cartesian_prod is equivalent
        # to stacking an 'ij' meshgrid, without the meshgrid `indexing` warning.
        pos = torch.arange(wsz, device = device)
        grid = torch.cartesian_prod(pos, pos)
        rel_ij = grid[:, None] - grid[None, :]      # (wsz*wsz, wsz*wsz, 2) pairwise offsets
        rel_pos_bias = self.dpb(rel_ij.float())     # (wsz*wsz, wsz*wsz), broadcast over batch and heads
        sim = sim + rel_pos_bias

        # attend
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # merge heads
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = wsz, y = wsz)
        out = self.to_out(out)

        # unfold the windows back to (batch, dim, height, width)
        if self.attn_type == 'short':
            out = rearrange(out, '(b h w) d s1 s2 -> b d (h s1) (w s2)', h = height // wsz, w = width // wsz)
        elif self.attn_type == 'long':
            out = rearrange(out, '(b h w) d l1 l2 -> b d (l1 h) (l2 w)', h = height // wsz, w = width // wsz)

        return out
class Transformer(nn.Module):
    """Stack of ``depth`` CrossFormer blocks.

    Each block applies four sub-layers in order, each wrapped in a residual
    connection: short-distance attention, feed-forward, long-distance
    attention, feed-forward.
    """

    def __init__(
        self,
        dim,
        *,
        local_window_size,
        global_window_size,
        depth = 4,
        dim_head = 32,
        attn_dropout = 0.,
        ff_dropout = 0.,
    ):
        super().__init__()
        blocks = []
        for _ in range(depth):
            short_range = Attention(dim, attn_type = 'short', window_size = local_window_size, dim_head = dim_head, dropout = attn_dropout)
            short_mlp = FeedForward(dim, dropout = ff_dropout)
            long_range = Attention(dim, attn_type = 'long', window_size = global_window_size, dim_head = dim_head, dropout = attn_dropout)
            long_mlp = FeedForward(dim, dropout = ff_dropout)
            blocks.append(nn.ModuleList([short_range, short_mlp, long_range, long_mlp]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        # every sub-layer is residual; they run in registration order
        for block in self.layers:
            for sublayer in block:
                x = sublayer(x) + x
        return x
# classes
class CrossFormer(nn.Module):
    """CrossFormer image classifier.

    The model has four stages. Each stage is a ``CrossEmbedLayer`` (multi-scale
    strided conv embedding) followed by a ``Transformer``. The final feature map
    is mean-pooled over its spatial axes and projected to ``num_classes``
    logits.

    Per-stage arguments accept either a 4-tuple or a single non-tuple value,
    which is repeated for every stage.
    """

    def __init__(
        self,
        *,
        dim = (64, 128, 256, 512),
        depth = (2, 2, 8, 2),
        global_window_size = (8, 4, 2, 1),
        local_window_size = 7,
        cross_embed_kernel_sizes = ((4, 8, 16, 32), (2, 4), (2, 4), (2, 4)),
        cross_embed_strides = (4, 2, 2, 2),
        num_classes = 1000,
        attn_dropout = 0.,
        ff_dropout = 0.,
        channels = 3
    ):
        super().__init__()

        # broadcast every per-stage setting to four entries, then validate
        stage_configs = [cast_tuple(setting, 4) for setting in (
            dim,
            depth,
            global_window_size,
            local_window_size,
            cross_embed_kernel_sizes,
            cross_embed_strides,
        )]
        for config in stage_configs:
            assert len(config) == 4
        dim, depth, global_window_size, local_window_size, cross_embed_kernel_sizes, cross_embed_strides = stage_configs

        # channel count entering / leaving each stage
        widths = (channels, *dim)
        stage_widths = list(zip(widths[:-1], widths[1:]))

        stages = []
        for (width_in, width_out), stage_depth, global_wsz, local_wsz, kernels, stride in zip(
            stage_widths, depth, global_window_size, local_window_size, cross_embed_kernel_sizes, cross_embed_strides
        ):
            embed = CrossEmbedLayer(width_in, width_out, kernels, stride = stride)
            transformer = Transformer(width_out, local_window_size = local_wsz, global_window_size = global_wsz, depth = stage_depth, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
            stages.append(nn.ModuleList([embed, transformer]))
        self.layers = nn.ModuleList(stages)

        # final logits: spatial mean pool, then linear classifier
        self.to_logits = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(dim[-1], num_classes)
        )

    def forward(self, x):
        for embed, transformer in self.layers:
            x = transformer(embed(x))
        return self.to_logits(x)