From 0ed18c89eab4df05c47f0ed2ea31756315c392ab Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 7 Mar 2023 19:29:50 -0800
Subject: [PATCH] separate a simple vit from mp3, so that simple vit can be used after being pretrained

---
 setup.py           |   2 +-
 vit_pytorch/mp3.py | 109 ++++++++++++++++++++++++++++++++-------------
 2 files changed, 80 insertions(+), 31 deletions(-)

diff --git a/setup.py b/setup.py
index 617893c..69f7c9d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.1.0',
+  version = '1.1.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/mp3.py b/vit_pytorch/mp3.py
index d3daa7c..42d9dae 100644
--- a/vit_pytorch/mp3.py
+++ b/vit_pytorch/mp3.py
@@ -7,23 +7,37 @@ from einops.layers.torch import Rearrange
 
 # helpers
 
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
 def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
-# pre-layernorm
+# positional embedding
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
+def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
+    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
+
+    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
+    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
+    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
+    omega = 1. / (temperature ** omega)
+
+    y = y.flatten()[:, None] * omega[None, :]
+    x = x.flatten()[:, None] * omega[None, :]
+    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
+    return pe.type(dtype)
+
+# feedforward
 
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -33,9 +47,9 @@ class FeedForward(nn.Module):
     def forward(self, x):
         return self.net(x)
 
-# cross attention
+# (cross)attention
 
-class CrossAttention(nn.Module):
+class Attention(nn.Module):
     def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         super().__init__()
         inner_dim = dim_head * heads
@@ -45,6 +59,8 @@ class CrossAttention(nn.Module):
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
+        self.norm = nn.LayerNorm(dim)
+
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
 
@@ -53,9 +69,13 @@ class CrossAttention(nn.Module):
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context):
+    def forward(self, x, context = None):
         b, n, _, h = *x.shape, self.heads
 
+        x = self.norm(x)
+
+        context = self.norm(context) if exists(context) else x
+
         qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
@@ -74,33 +94,31 @@ class Transformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, CrossAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
 
-    def forward(self, x, context):
+    def forward(self, x, context = None):
         for attn, ff in self.layers:
-            x = attn(x, context=context) + x
+            x = attn(x, context = context) + x
             x = ff(x) + x
         return x
 
-# Masked Position Prediction Pre-Training
-
-class MP3(nn.Module):
-    def __init__(self, *, image_size, patch_size, masking_ratio, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0.):
+class ViT(nn.Module):
+    def __init__(self, *, num_classes, image_size, patch_size, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0.):
         super().__init__()
         image_height, image_width = pair(image_size)
         patch_height, patch_width = pair(patch_size)
 
         assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
-        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
-        self.masking_ratio = masking_ratio
-
         num_patches = (image_height // patch_height) * (image_width // patch_width)
         patch_dim = channels * patch_height * patch_width
 
+        self.dim = dim
+        self.num_patches = num_patches
+
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
+            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim),
@@ -108,15 +126,46 @@ class MP3(nn.Module):
 
         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
 
+        self.to_latent = nn.Identity()
+        self.linear_head = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, num_classes)
+        )
+
+    def forward(self, img):
+        *_, h, w, dtype = *img.shape, img.dtype
+
+        x = self.to_patch_embedding(img)
+        pe = posemb_sincos_2d(x)
+        x = rearrange(x, 'b ... d -> b (...) d') + pe
+
+        x = self.transformer(x)
+        x = x.mean(dim = 1)
+
+        x = self.to_latent(x)
+        return self.linear_head(x)
+
+# Masked Position Prediction Pre-Training
+
+class MP3(nn.Module):
+    def __init__(self, vit: ViT, masking_ratio):
+        super().__init__()
+        self.vit = vit
+
+        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
+        self.masking_ratio = masking_ratio
+
+        dim = vit.dim
         self.mlp_head = nn.Sequential(
             nn.LayerNorm(dim),
-            nn.Linear(dim, num_patches)
+            nn.Linear(dim, vit.num_patches)
         )
-        self.out = nn.Softmax(dim = -1)
 
     def forward(self, img):
         device = img.device
-        tokens = self.to_patch_embedding(img)
+        tokens = self.vit.to_patch_embedding(img)
+        tokens = rearrange(tokens, 'b ... d -> b (...) d')
+
         batch, num_patches, *_ = tokens.shape
 
         # Masking
@@ -127,11 +176,11 @@ class MP3(nn.Module):
         batch_range = torch.arange(batch, device = device)[:, None]
         tokens_unmasked = tokens[batch_range, unmasked_indices]
 
-        x = rearrange(self.mlp_head(self.transformer(tokens, tokens_unmasked)), 'b n d -> (b n) d')
-        x = self.out(x)
+        attended_tokens = self.vit.transformer(tokens, tokens_unmasked)
+        logits = rearrange(self.mlp_head(attended_tokens), 'b n d -> (b n) d')
 
         # Define labels
-        labels = repeat(torch.arange(num_patches, device = device), 'n -> b n', b = batch).flatten()
-        loss = F.cross_entropy(logits, labels)
+        labels = repeat(torch.arange(num_patches, device = device), 'n -> (b n)', b = batch)
+        loss = F.cross_entropy(logits, labels)
 
         return loss
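
Usage sketch (not part of the patch): with the backbone split out, a ViT can be pretrained under the MP3 objective and then reused directly for classification. This is a minimal sketch assuming the classes are imported from the file changed in this diff (vit_pytorch/mp3.py); the hyperparameter values are illustrative only.

import torch
from vit_pytorch.mp3 import ViT, MP3

# backbone that is kept after pretraining (illustrative hyperparameters)
vit = ViT(
    num_classes = 1000,
    image_size = 256,
    patch_size = 8,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

# wrap the backbone for masked position prediction pretraining
mp3 = MP3(vit, masking_ratio = 0.75)

images = torch.randn(4, 3, 256, 256)

loss = mp3(images)   # cross entropy over the patch positions to be predicted
loss.backward()

# after pretraining, the same (now trained) ViT is used on its own
logits = vit(images) # (4, 1000)

Before this change, MP3 owned the patch embedding and transformer itself; since it now wraps a ViT instance (self.vit.to_patch_embedding, self.vit.transformer), the pretrained weights live in the backbone and nothing needs to be copied out afterwards.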