Update mae.py (#242)

update MAE so the decoded tokens stay in patch order and can be easily reshaped back into an image to visualize the reconstruction
Author: Srikumar Sastry
Date: 2022-10-17 12:41:10 -05:00
Committed by: GitHub
Parent: b4853d39c2
Commit: 9a95e7904e


@@ -28,7 +28,7 @@ class MAE(nn.Module):
         pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
 
         # decoder parameters
-
+        self.decoder_dim = decoder_dim
         self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
         self.mask_token = nn.Parameter(torch.randn(decoder_dim))
         self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
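Note: storing decoder_dim on the module is what lets the forward pass (third hunk below) preallocate a decoder sequence with one slot per patch position. A minimal runnable sketch of that allocation (the shapes here are toy values, not taken from the commit):

    import torch

    batch, num_patches, decoder_dim = 4, 64, 512
    device = torch.device('cpu')

    # one zero-initialized token per patch position, to be filled in by index
    decoder_tokens = torch.zeros(batch, num_patches, decoder_dim, device = device)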
@@ -73,7 +73,7 @@ class MAE(nn.Module):
         # reapply decoder position embedding to unmasked tokens
 
-        decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
+        unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
 
         # repeat mask tokens for number of masked, and add the positions using the masked indices derived above
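Note: the rename to unmasked_decoder_tokens matters because these tokens are no longer concatenated after the mask tokens; they are written back to their original patch positions by batched indexing in the next hunk. A small self-contained sketch of that indexing pattern (toy shapes and names, not from the commit):

    import torch

    batch, num_patches, dim = 2, 6, 4
    batch_range = torch.arange(batch)[:, None]               # (batch, 1), broadcasts against the index tensors
    unmasked_indices = torch.tensor([[0, 2, 5], [1, 3, 4]])  # each sample keeps different patch positions
    unmasked_tokens = torch.randn(batch, 3, dim)

    tokens = torch.zeros(batch, num_patches, dim)
    tokens[batch_range, unmasked_indices] = unmasked_tokens  # scatter each sample's tokens to its own positions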
@@ -81,13 +81,15 @@ class MAE(nn.Module):
         mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
 
         # concat the masked tokens to the decoder tokens and attend with decoder
 
-        decoder_tokens = torch.cat((mask_tokens, decoder_tokens), dim = 1)
+        decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
+        decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
+        decoder_tokens[batch_range, masked_indices] = mask_tokens
         decoded_tokens = self.decoder(decoder_tokens)
 
         # splice out the mask tokens and project to pixel values
 
-        mask_tokens = decoded_tokens[:, :num_masked]
+        mask_tokens = decoded_tokens[batch_range, masked_indices]
         pred_pixel_values = self.to_pixels(mask_tokens)
 
         # calculate reconstruction loss
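Note: this positional layout is what the commit message refers to. Because decoded_tokens, and hence the predicted patches, now line up with patch positions, a caller can scatter the predictions back among the original patches and undo the encoder's patch rearrangement to view the reconstruction. A hedged sketch only: MAE.forward in this file returns just the loss, so patches, pred_pixel_values, and masked_indices below are stand-ins a caller would need to expose, and the patch size (32) and grid height (8) are illustrative:

    import torch
    from einops import rearrange

    # toy stand-ins for intermediates of a forward pass (shapes are illustrative)
    batch, num_patches, num_masked = 1, 64, 48          # 8x8 grid of 32x32 patches, 75% masked
    pixel_values_per_patch = 32 * 32 * 3
    patches = torch.randn(batch, num_patches, pixel_values_per_patch)
    pred_pixel_values = torch.randn(batch, num_masked, pixel_values_per_patch)
    masked_indices = torch.stack([torch.randperm(num_patches)[:num_masked] for _ in range(batch)])

    # overwrite the masked positions with the decoder's predictions
    recon_patches = patches.clone()
    batch_range = torch.arange(batch)[:, None]
    recon_patches[batch_range, masked_indices] = pred_pixel_values

    # invert the encoder's 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)' patching to recover an image
    recon_image = rearrange(recon_patches, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 32, p2 = 32, h = 8)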