mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
151 lines
4.8 KiB
Python
151 lines
4.8 KiB
Python
from math import sqrt
|
|
import torch
|
|
from torch import nn, einsum
|
|
import torch.nn.functional as F
|
|
|
|
from einops import rearrange, repeat
|
|
from einops.layers.torch import Rearrange
|
|
|
|
# classes
|
|
|
|
class Residual(nn.Module):
|
|
def __init__(self, fn):
|
|
super().__init__()
|
|
self.fn = fn
|
|
|
|
def forward(self, x, **kwargs):
|
|
return self.fn(x, **kwargs) + x
|
|
|
|
class ExcludeCLS(nn.Module):
|
|
def __init__(self, fn):
|
|
super().__init__()
|
|
self.fn = fn
|
|
|
|
def forward(self, x, **kwargs):
|
|
cls_token, x = x[:, :1], x[:, 1:]
|
|
x = self.fn(x, **kwargs)
|
|
return torch.cat((cls_token, x), dim = 1)
|
|
|
|
# feed forward related classes
|
|
|
|
class DepthWiseConv2d(nn.Module):
|
|
def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
|
|
super().__init__()
|
|
self.net = nn.Sequential(
|
|
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
|
|
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
|
|
)
|
|
def forward(self, x):
|
|
return self.net(x)
|
|
|
|
class FeedForward(nn.Module):
|
|
def __init__(self, dim, hidden_dim, dropout = 0.):
|
|
super().__init__()
|
|
self.net = nn.Sequential(
|
|
nn.LayerNorm(dim),
|
|
nn.Conv2d(dim, hidden_dim, 1),
|
|
nn.Hardswish(),
|
|
DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
|
|
nn.Hardswish(),
|
|
nn.Dropout(dropout),
|
|
nn.Conv2d(hidden_dim, dim, 1),
|
|
nn.Dropout(dropout)
|
|
)
|
|
def forward(self, x):
|
|
h = w = int(sqrt(x.shape[-2]))
|
|
x = rearrange(x, 'b (h w) c -> b c h w', h = h, w = w)
|
|
x = self.net(x)
|
|
x = rearrange(x, 'b c h w -> b (h w) c')
|
|
return x
|
|
|
|
# attention
|
|
|
|
class Attention(nn.Module):
|
|
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
|
super().__init__()
|
|
inner_dim = dim_head * heads
|
|
|
|
self.heads = heads
|
|
self.scale = dim_head ** -0.5
|
|
|
|
self.norm = nn.LayerNorm(dim)
|
|
self.attend = nn.Softmax(dim = -1)
|
|
self.dropout = nn.Dropout(dropout)
|
|
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
|
|
|
self.to_out = nn.Sequential(
|
|
nn.Linear(inner_dim, dim),
|
|
nn.Dropout(dropout)
|
|
)
|
|
|
|
def forward(self, x):
|
|
b, n, _, h = *x.shape, self.heads
|
|
|
|
x = self.norm(x)
|
|
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
|
|
|
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
|
|
|
attn = self.attend(dots)
|
|
attn = self.dropout(attn)
|
|
|
|
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
|
return self.to_out(out)
|
|
|
|
class Transformer(nn.Module):
|
|
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
|
super().__init__()
|
|
self.layers = nn.ModuleList([])
|
|
for _ in range(depth):
|
|
self.layers.append(nn.ModuleList([
|
|
Residual(Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
|
ExcludeCLS(Residual(FeedForward(dim, mlp_dim, dropout = dropout)))
|
|
]))
|
|
def forward(self, x):
|
|
for attn, ff in self.layers:
|
|
x = attn(x)
|
|
x = ff(x)
|
|
return x
|
|
|
|
# main class
|
|
|
|
class LocalViT(nn.Module):
|
|
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
|
super().__init__()
|
|
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
|
num_patches = (image_size // patch_size) ** 2
|
|
patch_dim = channels * patch_size ** 2
|
|
|
|
self.to_patch_embedding = nn.Sequential(
|
|
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
|
nn.LayerNorm(patch_dim),
|
|
nn.Linear(patch_dim, dim),
|
|
nn.LayerNorm(dim),
|
|
)
|
|
|
|
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
|
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
|
self.dropout = nn.Dropout(emb_dropout)
|
|
|
|
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
|
|
|
self.mlp_head = nn.Sequential(
|
|
nn.LayerNorm(dim),
|
|
nn.Linear(dim, num_classes)
|
|
)
|
|
|
|
def forward(self, img):
|
|
x = self.to_patch_embedding(img)
|
|
b, n, _ = x.shape
|
|
|
|
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
|
x = torch.cat((cls_tokens, x), dim=1)
|
|
x += self.pos_embedding[:, :(n + 1)]
|
|
x = self.dropout(x)
|
|
|
|
x = self.transformer(x)
|
|
|
|
return self.mlp_head(x[:, 0])
|