mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
83 lines
2.8 KiB
Python
83 lines
2.8 KiB
Python
import math
|
|
import torch
|
|
from torch import nn
|
|
|
|
from vit_pytorch.vit_pytorch import Transformer
|
|
|
|
from einops import rearrange, repeat
|
|
from einops.layers.torch import Rearrange
|
|
|
|
# helpers
|
|
|
|
def exists(val):
    """Report whether ``val`` holds a value, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if val is None:
        return False
    return True
|
|
|
|
def conv_output_size(image_size, kernel_size, stride, padding):
    """Return the number of window positions along one spatial axis.

    Computes ``floor((image_size + 2 * padding - kernel_size) / stride) + 1``,
    the output-size formula PyTorch documents for ``nn.Unfold`` / ``nn.Conv2d``
    with dilation 1.

    Args:
        image_size: input length along the axis.
        kernel_size: window length.
        stride: step between consecutive windows (must be non-zero).
        padding: zero-padding added on each side of the axis.

    Returns:
        The output length as an ``int``. A value below 1 means the window
        does not fit inside the padded input.
    """
    # Integer floor division keeps the result exact for arbitrarily large sizes
    # (float division rounds past 2**53) and floors rather than truncating toward
    # zero when the numerator is negative. int() keeps the return type an int
    # even if float sizes are passed in.
    return int((image_size - kernel_size + 2 * padding) // stride + 1)
|
|
|
|
# classes
|
|
|
|
class RearrangeImage(nn.Module):
    """Fold a flat token sequence back into a square-ish feature map.

    The einops pattern fixes the input as rank 3, ``(b, (h w), c)``, and the
    output as ``(b, c, h, w)``. ``h`` is the integer part of the square root of
    the token count, and einops infers ``w``; einops raises if the token count
    is not divisible by ``h``.
    """

    def forward(self, x):
        num_tokens = x.shape[1]
        side = int(math.sqrt(num_tokens))
        return rearrange(x, 'b (h w) c -> b c h w', h = side)
|
|
|
|
# main class
|
|
|
|
class T2TViT(nn.Module):
    """Vision Transformer whose patch embedding is a Tokens-to-Token (T2T) stack.

    The image is tokenised by a sequence of "soft split" stages. Each stage
    unfolds overlapping ``kernel_size`` x ``kernel_size`` windows with the given
    ``stride`` (padding ``stride // 2``) and turns each window into one token.
    Every stage except the last then mixes those tokens with a single-head,
    depth-1 ``Transformer``. The final tokens are linearly projected to
    ``dim``, a learned class token and positional embedding are added, and the
    sequence passes through the main transformer and a LayerNorm + Linear head.

    Args:
        image_size: side length of the square input image; the positional
            embedding is sized for exactly this resolution.
        num_classes: number of output logits.
        dim: embedding width of the main transformer.
        depth, heads, mlp_dim: configuration of the default main
            ``Transformer``; required unless ``transformer`` is supplied.
        pool: ``'cls'`` to classify from the class token, ``'mean'`` to average
            over all tokens.
        channels: number of input image channels.
        dim_head: per-head width of the default main ``Transformer``.
        dropout: dropout passed to every ``Transformer`` constructed here.
        emb_dropout: dropout applied after adding positional embeddings.
        transformer: optional pre-built module used in place of the default
            main ``Transformer``.
        t2t_layers: iterable of ``(kernel_size, stride)`` pairs, one per soft
            split stage; must be non-empty.
    """

    def __init__(self, *, image_size, num_classes, dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))):
        super().__init__()
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # materialise so the last stage can be identified (and generators still work)
        t2t_layers = tuple(t2t_layers)
        # with no stages the final Linear would be applied to the raw image's last axis
        assert len(t2t_layers) > 0, 't2t_layers must contain at least one (kernel_size, stride) pair'

        layers = []
        layer_dim = channels
        output_image_size = image_size
        last_index = len(t2t_layers) - 1

        for i, (kernel_size, stride) in enumerate(t2t_layers):
            # each unfolded token concatenates a kernel_size x kernel_size window of the previous features
            layer_dim *= kernel_size ** 2
            is_first = i == 0
            is_last = i == last_index
            output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)

            layers.extend([
                # later stages receive tokens and must fold them back into a (b, c, h, w) map before unfolding
                RearrangeImage() if not is_first else nn.Identity(),
                nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
                Rearrange('b c n -> b n c'),
                # T2T-ViT projects the final soft split straight to `dim`. A token transformer here would
                # run at width channels * prod(k ** 2) (11907 for the defaults) with dim_head and mlp_dim
                # equal to that width, dwarfing the rest of the model. nn.Identity keeps the
                # Sequential indices of every stage unchanged.
                Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout) if not is_last else nn.Identity(),
            ])

        layers.append(nn.Linear(layer_dim, dim))
        self.to_patch_embedding = nn.Sequential(*layers)

        # one position per final token (output_image_size ** 2 of them) plus one for the class token
        self.pos_embedding = nn.Parameter(torch.randn(1, output_image_size ** 2 + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        if not exists(transformer):
            assert all([exists(depth), exists(heads), exists(mlp_dim)]), 'depth, heads, and mlp_dim must be supplied'
            self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        else:
            self.transformer = transformer

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Classify a batch of images.

        Args:
            img: image batch; presumably (batch, channels, image_size, image_size) — TODO confirm.

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.to_patch_embedding(img)
        b, _, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        # the full positional table is added, so the token count must match the configured image_size
        x += self.pos_embedding
        x = self.dropout(x)

        x = self.transformer(x)

        # 'mean' averages every token (class token included); 'cls' reads only the class token
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
|