Merge pull request #101 from zankner/mpp-fix

Mpp fix
Phil Wang
2021-06-16 14:24:26 -07:00
committed by GitHub


@@ -32,14 +32,27 @@ def get_mask_subset_with_prob(patched_input, prob):
 class MPPLoss(nn.Module):
     def __init__(self, patch_size, channels, output_channel_bits,
-                 max_pixel_val):
+                 max_pixel_val, mean, std):
         super(MPPLoss, self).__init__()
         self.patch_size = patch_size
         self.channels = channels
         self.output_channel_bits = output_channel_bits
         self.max_pixel_val = max_pixel_val
+
+        if mean:
+            self.mean = torch.tensor(mean).view(-1, 1, 1)
+        else:
+            self.mean = None
+        if std:
+            self.std = torch.tensor(std).view(-1, 1, 1)
+        else:
+            self.std = None
 
     def forward(self, predicted_patches, target, mask):
+        # un-normalize input
+        if self.mean is not None and self.std is not None:
+            target = target * self.std + self.mean
+
         # reshape target to patches
         p = self.patch_size
         target = rearrange(target,
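For reference, a minimal standalone sketch (not part of the diff; the ImageNet statistics are assumed example values) of what the added un-normalization step does: it inverts a torchvision-style per-channel normalization so the loss discretizes raw pixel values rather than normalized ones.

import torch

mean = [0.485, 0.456, 0.406]   # assumed example values (ImageNet channel means)
std = [0.229, 0.224, 0.225]    # assumed example values (ImageNet channel stds)

images = torch.rand(8, 3, 256, 256)              # raw pixel values in [0, 1]
m = torch.tensor(mean).view(-1, 1, 1)
s = torch.tensor(std).view(-1, 1, 1)

normalized = (images - m) / s                    # what a Normalize transform feeds the model
recovered = normalized * s + m                   # what MPPLoss.forward now does to `target`
assert torch.allclose(recovered, images, atol=1e-5)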
@@ -64,7 +77,6 @@ class MPPLoss(nn.Module):
                                         -1).to(discretized_target.device,
                                                discretized_target.dtype)
         target_label = torch.sum(bin_mask * discretized_target, -1)
-
         predicted_patches = predicted_patches[mask]
         target_label = target_label[mask]
         loss = F.cross_entropy(predicted_patches, target_label)
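As a side note, a small sketch (values assumed, not taken from the diff) of how the bin_mask sum above packs per-channel bin indices into a single class label for the cross-entropy:

import torch

output_channel_bits = 3                           # assumed default
channels = 3
channel_bins = 2 ** output_channel_bits           # 8 intensity bins per channel

discretized_target = torch.tensor([[1., 5., 7.]])            # one patch: per-channel bin indices
bin_mask = channel_bins ** torch.arange(0, channels)         # tensor([1, 8, 64]) -> base-8 place values
target_label = torch.sum(bin_mask * discretized_target, -1)  # 1*1 + 5*8 + 7*64 = 489
print(int(target_label))                          # class index in [0, 2**(output_channel_bits * channels))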
@@ -84,11 +96,13 @@ class MPP(nn.Module):
                  max_pixel_val=1.0,
                  mask_prob=0.15,
                  replace_prob=0.5,
-                 random_patch_prob=0.5):
+                 random_patch_prob=0.5,
+                 mean=None,
+                 std=None):
         super().__init__()
         self.transformer = transformer
         self.loss = MPPLoss(patch_size, channels, output_channel_bits,
-                            max_pixel_val)
+                            max_pixel_val, mean, std)
 
         # output transformation
         self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
@@ -102,7 +116,7 @@ class MPP(nn.Module):
         self.random_patch_prob = random_patch_prob
 
         # token ids
-        self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
+        self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))
 
     def forward(self, input, **kwargs):
         transformer = self.transformer
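For completeness, a usage sketch under the updated signature, following the README-style MPP setup; the normalization statistics and hyperparameters are assumed example values. mean/std are forwarded to MPPLoss so it can un-normalize its targets, and the mask token now matches the flattened patch dimension (channels * patch_size ** 2).

import torch
from vit_pytorch import ViT
from vit_pytorch.mpp import MPP

model = ViT(image_size=256, patch_size=32, num_classes=1000,
            dim=1024, depth=6, heads=8, mlp_dim=2048)

mpp_trainer = MPP(
    transformer=model,
    patch_size=32,
    dim=1024,
    mask_prob=0.15,
    random_patch_prob=0.30,
    replace_prob=0.50,
    mean=[0.485, 0.456, 0.406],   # assumed: the same stats used to normalize the inputs
    std=[0.229, 0.224, 0.225])

images = torch.randn(8, 3, 256, 256)   # already-normalized images
loss = mpp_trainer(images)
loss.backward()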