From d4daf7bd0fecae9d3bc7dd4ea1beb24d67f553f8 Mon Sep 17 00:00:00 2001
From: roydenwa <65123203+roydenwa@users.noreply.github.com>
Date: Mon, 24 Jul 2023 15:43:01 +0200
Subject: [PATCH] Support SimpleViT as encoder in MAE (#272)

support simplevit in mae
---
 vit_pytorch/mae.py        |  5 ++++-
 vit_pytorch/simple_vit.py | 31 +++++++++++++++++--------------
 2 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/vit_pytorch/mae.py b/vit_pytorch/mae.py
index e2b993c..7eb1679 100644
--- a/vit_pytorch/mae.py
+++ b/vit_pytorch/mae.py
@@ -49,7 +49,10 @@ class MAE(nn.Module):
         # patch to encoder tokens and add positions
 
         tokens = self.patch_to_emb(patches)
-        tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]
+        if self.encoder.pool == "cls":
+            tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
+        elif self.encoder.pool == "mean":
+            tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)
 
         # calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked
 
diff --git a/vit_pytorch/simple_vit.py b/vit_pytorch/simple_vit.py
index 2d6acf0..2b63b60 100644
--- a/vit_pytorch/simple_vit.py
+++ b/vit_pytorch/simple_vit.py
@@ -9,17 +9,15 @@ from einops.layers.torch import Rearrange
 def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
-def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
-    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
-
-    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
-    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
-    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
-    omega = 1. / (temperature ** omega)
+def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
+    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
+    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
+    omega = torch.arange(dim // 4) / (dim // 4 - 1)
+    omega = 1.0 / (temperature ** omega)
 
     y = y.flatten()[:, None] * omega[None, :]
-    x = x.flatten()[:, None] * omega[None, :]
-    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
+    x = x.flatten()[:, None] * omega[None, :]
+    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
     return pe.type(dtype)
 
 # classes
@@ -86,16 +84,21 @@ class SimpleViT(nn.Module):
 
         assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
 
-        num_patches = (image_height // patch_height) * (image_width // patch_width)
         patch_dim = channels * patch_height * patch_width
 
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
+            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim),
         )
 
+        self.pos_embedding = posemb_sincos_2d(
+            h = image_height // patch_height,
+            w = image_width // patch_width,
+            dim = dim,
+        )
+
         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
 
         self.to_latent = nn.Identity()
@@ -103,13 +106,13 @@ class SimpleViT(nn.Module):
             nn.LayerNorm(dim),
             nn.Linear(dim, num_classes)
         )
+        self.pool = "mean"
 
     def forward(self, img):
-        *_, h, w, dtype = *img.shape, img.dtype
+        device = img.device
 
         x = self.to_patch_embedding(img)
-        pe = posemb_sincos_2d(x)
-        x = rearrange(x, 'b ... d -> b (...) d') + pe
+        x += self.pos_embedding.to(device, dtype=x.dtype)
 
         x = self.transformer(x)
         x = x.mean(dim = 1)