from math import sqrt

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def cast_tuple(val, num):
    return val if isinstance(val, tuple) else (val,) * num

def conv_output_size(image_size, kernel_size, stride, padding = 0):
    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
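
# worked example (illustrative values, not fixed by the model): conv_output_size(224, 14, 7)
# evaluates to (224 - 14) / 7 + 1 = 31, so image_size = 224 with patch_size = 14 and stride
# patch_size // 2 = 7 yields a 31 x 31 grid of overlapping patch tokens.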
# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads

        x = self.norm(x)

        # project to queries, keys, values and split out the heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # scaled dot-product attention
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        # aggregate values and merge the heads back
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
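
# illustrative shapes (assumed values): with dim = 256, heads = 16 and dim_head = 64,
# Attention maps an input of shape (b, n, 256) to queries, keys and values of shape
# (b, 16, n, 64) each, and returns an output of the same (b, n, 256) shape.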
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        # pre-norm residual blocks: Attention and FeedForward both normalize their
        # input internally, so only the residual additions appear here
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
# depthwise convolution, for pooling

class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            # depthwise convolution (groups = dim_in), followed by a 1x1 pointwise convolution
            nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.Conv2d(dim_out, dim_out, kernel_size = 1, bias = bias)
        )

    def forward(self, x):
        return self.net(x)
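
# illustrative sizes (assumed values): DepthWiseConv2d(256, 512, kernel_size = 3,
# padding = 1, stride = 2) maps a (b, 256, 31, 31) feature map to (b, 512, 16, 16),
# since (31 + 2 * 1 - 3) // 2 + 1 = 16.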
# pooling layer

class Pool(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.downsample = DepthWiseConv2d(dim, dim * 2, kernel_size = 3, stride = 2, padding = 1)
        self.cls_ff = nn.Linear(dim, dim * 2)

    def forward(self, x):
        cls_token, tokens = x[:, :1], x[:, 1:]

        # the cls token is widened with a linear layer, while the spatial tokens are
        # reshaped back to a 2d grid and downsampled by the strided depthwise convolution
        cls_token = self.cls_ff(cls_token)

        tokens = rearrange(tokens, 'b (h w) c -> b c h w', h = int(sqrt(tokens.shape[1])))
        tokens = self.downsample(tokens)
        tokens = rearrange(tokens, 'b c h w -> b (h w) c')

        return torch.cat((cls_token, tokens), dim = 1)
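
# illustrative token counts (assumed values): with a 31 x 31 grid of spatial tokens at
# dim = 256, Pool returns a 16 x 16 grid at dim = 512, going from 1 + 961 to 1 + 256
# tokens while the embedding width doubles; the cls token is widened by cls_ff alone.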
# main class

class PiT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        channels = 3
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing'
        heads = cast_tuple(heads, len(depth))

        patch_dim = channels * patch_size ** 2

        # overlapping patch embedding: patches of size patch_size are extracted with a
        # stride of patch_size // 2, then projected to the transformer dimension
        self.to_patch_embedding = nn.Sequential(
            nn.Unfold(kernel_size = patch_size, stride = patch_size // 2),
            Rearrange('b c n -> b n c'),
            nn.Linear(patch_dim, dim)
        )

        output_size = conv_output_size(image_size, patch_size, patch_size // 2)
        num_patches = output_size ** 2

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # stack of transformer stages, with a Pool (token downsampling + dim doubling)
        # between consecutive stages
        layers = []

        for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
            not_last = ind < (len(depth) - 1)

            layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))

            if not_last:
                layers.append(Pool(dim))
                dim *= 2

        self.layers = nn.Sequential(*layers)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :n + 1]
        x = self.dropout(x)

        x = self.layers(x)

        # classification from the cls token only
        return self.mlp_head(x[:, 0])
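
# minimal usage sketch (illustrative hyperparameters, not prescribed by this file):
# builds a small PiT and runs a random image through it.
if __name__ == '__main__':
    v = PiT(
        image_size = 224,
        patch_size = 14,
        dim = 256,
        num_classes = 1000,
        depth = (3, 3, 3),          # three stages, pooled between consecutive stages
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    img = torch.randn(1, 3, 224, 224)
    preds = v(img)                  # expected shape: (1, 1000)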