Compare commits

...

1 Commits
1.2.6 ... 1.1.1

Author SHA1 Message Date
lucidrains
0ed18c89ea separate a simple vit from mp3, so that simple vit can be used after being pretrained 2023-03-07 19:29:50 -08:00
2 changed files with 80 additions and 31 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.1.0',
version = '1.1.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',

View File

@@ -7,23 +7,37 @@ from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# pre-layernorm
# positional embedding
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
omega = 1. / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)
# feedforward
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -33,9 +47,9 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
# cross attention
# (cross)attention
class CrossAttention(nn.Module):
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
@@ -45,6 +59,8 @@ class CrossAttention(nn.Module):
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.norm = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -53,9 +69,13 @@ class CrossAttention(nn.Module):
nn.Dropout(dropout)
)
def forward(self, x, context):
def forward(self, x, context = None):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
context = self.norm(context) if exists(context) else x
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
@@ -74,33 +94,31 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, CrossAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, context):
def forward(self, x, context = None):
for attn, ff in self.layers:
x = attn(x, context=context) + x
x = attn(x, context = context) + x
x = ff(x) + x
return x
# Masked Position Prediction Pre-Training
class MP3(nn.Module):
def __init__(self, *, image_size, patch_size, masking_ratio, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0.):
class ViT(nn.Module):
def __init__(self, *, num_classes, image_size, patch_size, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
self.masking_ratio = masking_ratio
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.dim = dim
self.num_patches = num_patches
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
@@ -108,15 +126,46 @@ class MP3(nn.Module):
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
*_, h, w, dtype = *img.shape, img.dtype
x = self.to_patch_embedding(img)
pe = posemb_sincos_2d(x)
x = rearrange(x, 'b ... d -> b (...) d') + pe
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
# Masked Position Prediction Pre-Training
class MP3(nn.Module):
def __init__(self, vit: ViT, masking_ratio):
super().__init__()
self.vit = vit
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
self.masking_ratio = masking_ratio
dim = vit.dim
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_patches)
nn.Linear(dim, vit.num_patches)
)
self.out = nn.Softmax(dim = -1)
def forward(self, img):
device = img.device
tokens = self.to_patch_embedding(img)
tokens = self.vit.to_patch_embedding(img)
tokens = rearrange(tokens, 'b ... d -> b (...) d')
batch, num_patches, *_ = tokens.shape
# Masking
@@ -127,11 +176,11 @@ class MP3(nn.Module):
batch_range = torch.arange(batch, device = device)[:, None]
tokens_unmasked = tokens[batch_range, unmasked_indices]
x = rearrange(self.mlp_head(self.transformer(tokens, tokens_unmasked)), 'b n d -> (b n) d')
x = self.out(x)
attended_tokens = self.vit.transformer(tokens, tokens_unmasked)
logits = rearrange(self.mlp_head(attended_tokens), 'b n d -> (b n) d')
# Define labels
labels = repeat(torch.arange(num_patches, device = device), 'n -> b n', b = batch).flatten()
loss = F.cross_entropy(x, labels)
labels = repeat(torch.arange(num_patches, device = device), 'n -> (b n)', b = batch)
loss = F.cross_entropy(logits, labels)
return loss