From 0082301f9e7c99758eff434472407e8cb30db5f5 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 3 Jan 2022 12:56:18 -0800
Subject: [PATCH] build @jrounds suggestion

---
 vit_pytorch/mae.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/vit_pytorch/mae.py b/vit_pytorch/mae.py
index 7b076b5..e095da1 100644
--- a/vit_pytorch/mae.py
+++ b/vit_pytorch/mae.py
@@ -14,11 +14,13 @@ class MAE(nn.Module):
         masking_ratio = 0.75,
         decoder_depth = 1,
         decoder_heads = 8,
-        decoder_dim_head = 64
+        decoder_dim_head = 64,
+        apply_decoder_pos_emb_all = False # whether to (re)apply decoder positional embedding to encoder unmasked tokens
     ):
         super().__init__()
         assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
         self.masking_ratio = masking_ratio
+        self.apply_decoder_pos_emb_all = apply_decoder_pos_emb_all
 
         # extract some hyperparameters and functions from encoder (vision transformer to be trained)
 
@@ -71,6 +73,11 @@ class MAE(nn.Module):
 
         decoder_tokens = self.enc_to_dec(encoded_tokens)
 
+        # reapply decoder position embedding to unmasked tokens, if desired
+
+        if self.apply_decoder_pos_emb_all:
+            decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
+
         # repeat mask tokens for number of masked, and add the positions using the masked indices derived above
 
         mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
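
For context, a minimal usage sketch of the new flag, following the vit_pytorch README pattern for MAE pretraining; the specific ViT/MAE hyperparameter values below are illustrative and are not part of this patch:

import torch
from vit_pytorch import ViT
from vit_pytorch.mae import MAE

# encoder: the vision transformer to be pretrained via masked reconstruction
encoder = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

# with apply_decoder_pos_emb_all = True, the decoder positional embedding is
# also added to the unmasked (encoder-derived) tokens, not just the mask tokens
mae = MAE(
    encoder = encoder,
    masking_ratio = 0.75,
    decoder_dim = 512,
    decoder_depth = 6,
    apply_decoder_pos_emb_all = True
)

images = torch.randn(8, 3, 256, 256)
loss = mae(images)   # reconstruction loss on masked patches
loss.backward()

The default remains False, so existing training scripts are unaffected unless they opt in.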