mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
@@ -49,7 +49,10 @@ class MAE(nn.Module):
|
||||
# patch to encoder tokens and add positions
|
||||
|
||||
tokens = self.patch_to_emb(patches)
|
||||
tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]
|
||||
if self.encoder.pool == "cls":
|
||||
tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
|
||||
elif self.encoder.pool == "mean":
|
||||
tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)
|
||||
|
||||
# calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked
|
||||
|
||||
|
||||
@@ -9,17 +9,15 @@ from einops.layers.torch import Rearrange
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
|
||||
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
||||
|
||||
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
|
||||
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
|
||||
omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
|
||||
omega = 1. / (temperature ** omega)
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
@@ -86,16 +84,21 @@ class SimpleViT(nn.Module):
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
@@ -103,13 +106,13 @@ class SimpleViT(nn.Module):
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
self.pool = "mean"
|
||||
|
||||
def forward(self, img):
|
||||
*_, h, w, dtype = *img.shape, img.dtype
|
||||
device = img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
pe = posemb_sincos_2d(x)
|
||||
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
||||
x += self.pos_embedding.to(device, dtype=x.dtype)
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
Reference in New Issue
Block a user