From a2df3632244c796b45c90e1e4066a784664fb669 Mon Sep 17 00:00:00 2001
From: Zack Ankner
Date: Thu, 29 Apr 2021 15:43:22 -0400
Subject: [PATCH] adding un-normalizing targets and fix for mask token
 dimension

---
 vit_pytorch/mpp.py | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/vit_pytorch/mpp.py b/vit_pytorch/mpp.py
index ef59a79..e6fdf78 100644
--- a/vit_pytorch/mpp.py
+++ b/vit_pytorch/mpp.py
@@ -32,14 +32,27 @@ def get_mask_subset_with_prob(patched_input, prob):
 
 class MPPLoss(nn.Module):
     def __init__(self, patch_size, channels, output_channel_bits,
-                 max_pixel_val):
+                 max_pixel_val, mean, std):
         super(MPPLoss, self).__init__()
         self.patch_size = patch_size
         self.channels = channels
         self.output_channel_bits = output_channel_bits
         self.max_pixel_val = max_pixel_val
 
+        if mean:
+            self.mean = torch.tensor(mean).view(-1, 1, 1)
+        else:
+            self.mean = None
+        if std:
+            self.std = torch.tensor(std).view(-1, 1, 1)
+        else:
+            self.std = None
+
     def forward(self, predicted_patches, target, mask):
+        # un-normalize input
+        if self.mean is not None and self.std is not None:
+            target = target * self.std + self.mean
+
         # reshape target to patches
         p = self.patch_size
         target = rearrange(target,
@@ -64,7 +77,6 @@ class MPPLoss(nn.Module):
                                -1).to(discretized_target.device,
                                       discretized_target.dtype)
         target_label = torch.sum(bin_mask * discretized_target, -1)
-        predicted_patches = predicted_patches[mask]
         target_label = target_label[mask]
         loss = F.cross_entropy(predicted_patches, target_label)
 
@@ -84,12 +96,14 @@ class MPP(nn.Module):
                  max_pixel_val=1.0,
                  mask_prob=0.15,
                  replace_prob=0.5,
-                 random_patch_prob=0.5):
+                 random_patch_prob=0.5,
+                 mean=None,
+                 std=None):
         super().__init__()
         self.transformer = transformer
 
         self.loss = MPPLoss(patch_size, channels, output_channel_bits,
-                            max_pixel_val)
+                            max_pixel_val, mean, std)
 
         # output transformation
         self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
@@ -103,7 +117,7 @@ class MPP(nn.Module):
         self.random_patch_prob = random_patch_prob
 
         # token ids
-        self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
+        self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))
 
     def forward(self, input, **kwargs):
         transformer = self.transformer
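
Usage sketch for the new mean/std arguments: a minimal, hedged example, not
taken from this patch. It assumes the ViT/MPP constructor values shown in the
vit-pytorch README of this era, and the ImageNet statistics are the standard
torchvision constants. Pass the same per-channel statistics used to normalize
the input images, so MPPLoss can recover raw pixel targets
(target * std + mean) before discretizing them into bins.

    import torch
    from vit_pytorch import ViT
    from vit_pytorch.mpp import MPP

    model = ViT(
        image_size=256,
        patch_size=32,
        num_classes=1000,
        dim=1024,
        depth=6,
        heads=8,
        mlp_dim=2048)

    mpp_trainer = MPP(
        transformer=model,
        patch_size=32,
        dim=1024,
        mask_prob=0.15,          # probability of using a token in the masked prediction task
        random_patch_prob=0.30,  # probability of randomly replacing a token used for mpp
        replace_prob=0.50,       # probability of replacing an mpp token with the mask token
        mean=(0.485, 0.456, 0.406),  # assumed: same stats the inputs were normalized with
        std=(0.229, 0.224, 0.225))

    opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)

    images = torch.randn(8, 3, 256, 256)  # stand-in for a batch of normalized images
    loss = mpp_trainer(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

On the mask token fix: the learned token is now sized to a flattened raw
patch (channels * patch_size ** 2), matching the (p1 p2 c) patch layout the
rearrange in MPPLoss uses, rather than the post-embedding dimension
(dim * channels), so it can stand in for patches before the patch-to-embedding
projection.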