# Mirror of https://github.com/lucidrains/vit-pytorch.git
# Synced 2025-12-30 08:02:29 +00:00
# 160 lines · 4.6 KiB · Python
import torch
|
|
from torch import nn
|
|
from torch.nn import Module, ModuleList
|
|
|
|
from einops import rearrange
|
|
from einops.layers.torch import Rearrange
|
|
|
|
# helpers
|
|
|
|
def exists(v):
    """Report whether *v* holds a value, i.e. is anything other than None.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    return not (v is None)
|
|
|
|
def default(v, d):
    """Return *v* unless it is None, in which case return the fallback *d*."""
    if v is None:
        return d
    return v
|
|
|
|
def pair(t):
    """Normalise *t* to a pair.

    Tuples are returned unchanged (whatever their length); any other value,
    including a list, is duplicated into ``(t, t)``.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)
|
|
|
|
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    """Build fixed 2D sine-cosine positional embeddings for an ``h x w`` grid.

    Grid cells are enumerated row-major (``meshgrid(..., indexing="ij")`` then
    flattened), so row ``i * w + j`` of the result encodes cell ``(y=i, x=j)``.
    The feature axis is split into four equal quarters laid out as
    ``[sin(x * omega), cos(x * omega), sin(y * omega), cos(y * omega)]``.
    When ``dim > 4``, ``omega`` is geometrically spaced from ``1`` down to
    ``1 / temperature``; when ``dim == 4`` there is a single frequency ``1``.

    Args:
        h: number of grid rows.
        w: number of grid columns.
        dim: embedding size; must be a multiple of 4.
        temperature: base of the geometric frequency progression.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(h * w, dim)``.

    Raises:
        AssertionError: if ``dim`` is not a multiple of 4.
    """
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"

    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")

    quarter = dim // 4
    # Normalise the frequency exponents to [0, 1]. With dim == 4 there is only
    # one frequency and `quarter - 1` is 0, which made omega 0/0 = NaN and
    # poisoned the whole embedding; clamping the denominator to 1 gives omega = 1.
    omega = torch.arange(quarter) / max(quarter - 1, 1)
    omega = 1.0 / (temperature ** omega)

    # Outer products: (h * w, 1) positions times (1, quarter) frequencies.
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)
|
|
|
|
# classes
|
|
|
|
def FeedForward(dim, hidden_dim):
    """Pre-norm MLP block: LayerNorm -> Linear(dim, hidden_dim) -> GELU -> Linear(hidden_dim, dim).

    Returns an ``nn.Sequential`` whose output has the same last dimension as
    its input. The residual connection is added by the caller, not here.
    """
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, dim),
    ]
    return nn.Sequential(*layers)
|
|
|
|
class Attention(Module):
    """Multi-head self-attention with an optional value-residual mix.

    The input is layer-normalised, projected to queries/keys/values, and
    attended over with scaled dot-product softmax attention. When ``forward``
    receives a ``value_residual``, this layer's values are blended with it
    before attention: ``v * mix + value_residual * (1 - mix)``. ``mix`` is a
    sigmoid gate predicted per head and token from the normalised input when
    ``learned_value_residual_mix`` is True, and the constant 0.5 otherwise.

    Args:
        dim: input / output feature dimension.
        heads: number of attention heads.
        dim_head: per-head feature dimension.
        learned_value_residual_mix: learn the mixing gate instead of using 0.5.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # Learned gate in [0, 1], one value per (head, token), reshaped to
        # (b, h, n, 1) so it broadcasts over the per-head value features.
        # Without a learned gate this is None and forward uses a constant 0.5.
        # (Previously a lambda was stored here, which made the whole model
        # unpicklable, e.g. torch.save(model) failed.)
        self.to_residual_mix = nn.Sequential(
            nn.Linear(dim, heads),
            nn.Sigmoid(),
            Rearrange('b n h -> b h n 1')
        ) if learned_value_residual_mix else None

    def forward(self, x, value_residual = None):
        """Attend over ``x``; return ``(output, values)``.

        Args:
            x: tokens of shape (b, n, dim), as implied by the 'b n (h d)'
               rearrange of the qkv projection.
            value_residual: values from an earlier layer that broadcast against
               this layer's values (b, h, n, d), or None to skip mixing.

        Returns:
            Tuple of the projected output, shape (b, n, dim), and the values
            used for attention (after any mixing), shape (b, h, n, d).
        """
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        if exists(value_residual):
            mix = self.to_residual_mix(x) if exists(self.to_residual_mix) else 0.5
            v = v * mix + value_residual * (1. - mix)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out), v
|
|
|
|
class Transformer(Module):
    """Stack of ``depth`` residual (Attention, FeedForward) layers with value residual learning.

    The values returned by the first layer's attention are kept and passed as
    ``value_residual`` to every later layer. Only layers after the first are
    built with ``learned_value_residual_mix=True``. A final LayerNorm is
    applied to the output.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = ModuleList([])

        for layer_index in range(depth):
            # The first layer supplies the residual values, so it has nothing to mix.
            learns_mix = layer_index != 0
            attention = Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = learns_mix)
            feedforward = FeedForward(dim, mlp_dim)
            self.layers.append(ModuleList([attention, feedforward]))

    def forward(self, x):
        first_values = None

        for attn, ff in self.layers:
            attn_out, values = attn(x, value_residual = first_values)

            # Capture the values from the first layer only; later ones are ignored.
            if first_values is None:
                first_values = values

            x = attn_out + x
            x = ff(x) + x

        return self.norm(x)
|
|
|
|
class SimpleViT(Module):
    """Simple Vision Transformer classifier using value residual learning.

    Pipeline: cut the image into non-overlapping patches, embed each patch
    (LayerNorm -> Linear -> LayerNorm), add fixed 2D sine-cosine positional
    embeddings, run the ``Transformer`` stack, mean-pool over tokens, and
    project to class logits.

    Keyword-only args:
        image_size: int or (height, width) of input images.
        patch_size: int or (height, width) of each patch; must divide image_size.
        num_classes: number of output logits.
        dim: token embedding width.
        depth: number of transformer layers.
        heads: attention heads per layer.
        mlp_dim: hidden width of each feed-forward block.
        channels: number of input image channels.
        dim_head: per-head attention width.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # Fixed (non-learned) positional embeddings, one row per patch. They are
        # registered as a non-persistent buffer so model.to()/.cuda()/.half()
        # move them with the parameters. A plain attribute stayed on the CPU and
        # was copied to the device on every forward. Because the buffer is
        # non-persistent, state_dict keys are unchanged.
        self.register_buffer(
            'pos_embedding',
            posemb_sincos_2d(
                h = image_height // patch_height,
                w = image_width // patch_width,
                dim = dim,
            ),
            persistent = False,
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        """Classify ``img``, assumed (batch, channels, height, width) per the patch Rearrange.

        Returns logits of shape (batch, num_classes).
        """
        device = img.device

        x = self.to_patch_embedding(img)
        # Usually a no-op now that the buffer follows the module. The call is kept
        # to handle a dtype mismatch (e.g. autocast) or an input on another device.
        x += self.pos_embedding.to(device, dtype=x.dtype)

        x = self.transformer(x)
        x = x.mean(dim = 1)  # mean-pool over the patch tokens

        x = self.to_latent(x)
        return self.linear_head(x)
|
|
|
|
# quick test
|
|
|
|
if __name__ == '__main__':
    v = SimpleViT(
        num_classes = 1000,
        image_size = 256,
        patch_size = 8,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048,
    )

    images = torch.randn(2, 3, 256, 256)

    logits = v(images)

    # Check the output so the smoke test fails loudly on a broken forward pass:
    # there should be one logit per class for each image in the batch.
    assert logits.shape == (2, 1000), f'unexpected logits shape: {tuple(logits.shape)}'
|