diff --git a/vit_pytorch/mae.py b/vit_pytorch/mae.py
index 271926e..b2d750a 100644
--- a/vit_pytorch/mae.py
+++ b/vit_pytorch/mae.py
@@ -28,7 +28,7 @@ class MAE(nn.Module):
         pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
 
         # decoder parameters
-
+        self.decoder_dim = decoder_dim
         self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
         self.mask_token = nn.Parameter(torch.randn(decoder_dim))
         self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
@@ -73,7 +73,7 @@ class MAE(nn.Module):
 
         # reapply decoder position embedding to unmasked tokens
 
-        decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
+        unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
 
         # repeat mask tokens for number of masked, and add the positions using the masked indices derived above
 
@@ -81,13 +81,15 @@ class MAE(nn.Module):
         mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
 
         # concat the masked tokens to the decoder tokens and attend with decoder
-
-        decoder_tokens = torch.cat((mask_tokens, decoder_tokens), dim = 1)
+
+        decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
+        decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
+        decoder_tokens[batch_range, masked_indices] = mask_tokens
         decoded_tokens = self.decoder(decoder_tokens)
 
         # splice out the mask tokens and project to pixel values
 
-        mask_tokens = decoded_tokens[:, :num_masked]
+        mask_tokens = decoded_tokens[batch_range, masked_indices]
         pred_pixel_values = self.to_pixels(mask_tokens)
 
         # calculate reconstruction loss
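
The core of the change: the old code concatenated the mask tokens in front of the unmasked tokens, so `decoded_tokens[:, :num_masked]` only recovered the mask tokens because of that concatenation order. The new code instead scatters both token sets into a zero tensor at their original patch indices, and then gathers the masked positions back out with the same `batch_range, masked_indices` advanced indexing after decoding. Below is a minimal, self-contained sketch of that scatter/gather indexing, outside the diff; the sizes are illustrative, the index derivation mirrors what `MAE.forward` does earlier in the file, and the decoder call is stubbed out with an identity.

import torch

# illustrative sizes, not taken from the diff
batch, num_patches, num_masked, decoder_dim = 2, 8, 6, 4

# per-sample random permutation split into masked / unmasked indices,
# mirroring how MAE.forward derives them earlier in mae.py
rand_indices = torch.rand(batch, num_patches).argsort(dim = -1)
masked_indices = rand_indices[:, :num_masked]
unmasked_indices = rand_indices[:, num_masked:]

# (batch, 1) index that broadcasts against the (batch, k) index tensors
batch_range = torch.arange(batch)[:, None]

# stand-ins for the tokens computed earlier in forward()
unmasked_decoder_tokens = torch.randn(batch, num_patches - num_masked, decoder_dim)
mask_tokens = torch.randn(batch, num_masked, decoder_dim)

# scatter: every token lands at its original patch index
decoder_tokens = torch.zeros(batch, num_patches, decoder_dim)
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
decoder_tokens[batch_range, masked_indices] = mask_tokens

# gather: the same advanced indexing recovers exactly the masked positions
# (self.decoder(...) is replaced by an identity here; it preserves sequence length)
decoded_tokens = decoder_tokens
assert torch.equal(decoded_tokens[batch_range, masked_indices], mask_tokens)

With this layout the final gather no longer depends on where the mask tokens sit in the sequence, which is why the diff can replace `decoded_tokens[:, :num_masked]` with `decoded_tokens[batch_range, masked_indices]`.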