diff --git a/README.md b/README.md
index bdb2bbf..dcc359a 100644
--- a/README.md
+++ b/README.md
@@ -437,7 +437,7 @@ mpp_trainer = MPP(
 opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
 
 def sample_unlabelled_images():
-    return torch.randn(20, 3, 256, 256)
+    return torch.FloatTensor(20, 3, 256, 256).uniform_(0., 1.)
 
 for _ in range(100):
     images = sample_unlabelled_images()
diff --git a/setup.py b/setup.py
index 00cc97b..93e7c30 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.19.5',
+  version = '0.19.6',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/mpp.py b/vit_pytorch/mpp.py
index bf093b1..e8b044d 100644
--- a/vit_pytorch/mpp.py
+++ b/vit_pytorch/mpp.py
@@ -1,20 +1,20 @@
 import math
-from functools import reduce
 
 import torch
 from torch import nn
 import torch.nn.functional as F
 
-from einops import rearrange, repeat
+from einops import rearrange, repeat, reduce
 
 # helpers
 
+def exists(val):
+    return val is not None
 
 def prob_mask_like(t, prob):
     batch, seq_length, _ = t.shape
     return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob
 
-
 def get_mask_subset_with_prob(patched_input, prob):
     batch, seq_len, _, device = *patched_input.shape, patched_input.device
     max_masked = math.ceil(prob * seq_len)
@@ -31,55 +31,45 @@ def get_mask_subset_with_prob(patched_input, prob):
 
 
 class MPPLoss(nn.Module):
-    def __init__(self, patch_size, channels, output_channel_bits,
-                 max_pixel_val, mean, std):
-        super(MPPLoss, self).__init__()
+    def __init__(
+        self,
+        patch_size,
+        channels,
+        output_channel_bits,
+        max_pixel_val,
+        mean,
+        std
+    ):
+        super().__init__()
         self.patch_size = patch_size
         self.channels = channels
         self.output_channel_bits = output_channel_bits
         self.max_pixel_val = max_pixel_val
 
-        if mean:
-            self.mean = torch.tensor(mean).view(-1, 1, 1)
-        else:
-            self.mean = None
-        if std:
-            self.std = torch.tensor(std).view(-1, 1, 1)
-        else:
-            self.std = None
+        self.mean = torch.tensor(mean).view(-1, 1, 1) if mean else None
+        self.std = torch.tensor(std).view(-1, 1, 1) if std else None
 
     def forward(self, predicted_patches, target, mask):
+        p, c, mpv, bits, device = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits, target.device
+        bin_size = mpv / (2 ** bits)
+
         # un-normalize input
-        if self.mean is not None and self.std is not None:
+        if exists(self.mean) and exists(self.std):
             target = target * self.std + self.mean
 
         # reshape target to patches
-        p = self.patch_size
-        target = rearrange(target,
-                           "b c (h p1) (w p2) -> b (h w) c (p1 p2) ",
-                           p1=p,
-                           p2=p)
+        target = target.clamp(max = mpv) # clamp just in case
+        avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1 = p, p2 = p).contiguous()
 
-        avg_target = target.mean(dim=3)
-
-        bin_size = self.max_pixel_val / self.output_channel_bits
-        channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size).to(avg_target.device)
+        channel_bins = torch.arange(bin_size, mpv, bin_size, device = device)
         discretized_target = torch.bucketize(avg_target, channel_bins)
-        discretized_target = F.one_hot(discretized_target,
-                                       self.output_channel_bits)
-        c, bi = self.channels, self.output_channel_bits
-        discretized_target = rearrange(discretized_target,
-                                       "b n c bi -> b n (c bi)",
-                                       c=c,
-                                       bi=bi)
-        bin_mask = 2**torch.arange(c * bi - 1, -1,
-                                   -1).to(discretized_target.device,
-                                          discretized_target.dtype)
-        target_label = torch.sum(bin_mask * discretized_target, -1)
 
-        predicted_patches = predicted_patches[mask]
-        target_label = target_label[mask]
-        loss = F.cross_entropy(predicted_patches, target_label)
+        bin_mask = (2 ** bits) ** torch.arange(0, c, device = device).long()
+        bin_mask = rearrange(bin_mask, 'c -> () () c')
+
+        target_label = torch.sum(bin_mask * discretized_target, dim = -1)
+
+        loss = F.cross_entropy(predicted_patches[mask], target_label[mask])
         return loss
 
 
@@ -87,18 +77,20 @@ class MPPLoss(nn.Module):
 
 
 class MPP(nn.Module):
-    def __init__(self,
-                 transformer,
-                 patch_size,
-                 dim,
-                 output_channel_bits=3,
-                 channels=3,
-                 max_pixel_val=1.0,
-                 mask_prob=0.15,
-                 replace_prob=0.5,
-                 random_patch_prob=0.5,
-                 mean=None,
-                 std=None):
+    def __init__(
+        self,
+        transformer,
+        patch_size,
+        dim,
+        output_channel_bits=3,
+        channels=3,
+        max_pixel_val=1.0,
+        mask_prob=0.15,
+        replace_prob=0.5,
+        random_patch_prob=0.5,
+        mean=None,
+        std=None
+    ):
         super().__init__()
         self.transformer = transformer
         self.loss = MPPLoss(patch_size, channels, output_channel_bits,
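Below is a small, self-contained Python sketch (not part of the patch) of the target-label construction that the reworked MPPLoss.forward performs: per-patch channel means are bucketized into 2 ** output_channel_bits bins, and the per-channel bins are then packed into a single class index with base-(2 ** output_channel_bits) weights. The names images and labels and the concrete sizes are illustrative assumptions only.

import torch
from einops import rearrange, reduce

patch_size = 32           # p in the diff
channels = 3              # c
output_channel_bits = 3   # bits
max_pixel_val = 1.0       # mpv

images = torch.rand(2, channels, 256, 256)   # dummy batch in [0, 1]

bin_size = max_pixel_val / (2 ** output_channel_bits)

# mean pixel value of every patch, per channel: (b, num_patches, c)
avg_target = reduce(images, 'b c (h p1) (w p2) -> b (h w) c', 'mean',
                    p1 = patch_size, p2 = patch_size)

# discretize each channel mean into one of 2 ** bits bins
channel_bins = torch.arange(bin_size, max_pixel_val, bin_size)
discretized_target = torch.bucketize(avg_target, channel_bins)

# pack the per-channel bins into one class index, base 2 ** bits per channel
bin_mask = (2 ** output_channel_bits) ** torch.arange(0, channels).long()
bin_mask = rearrange(bin_mask, 'c -> () () c')
labels = torch.sum(bin_mask * discretized_target, dim = -1)

print(labels.shape)       # torch.Size([2, 64]) -- one label per patch
print(int(labels.max()))  # always < (2 ** bits) ** channels == 512

Each masked patch thus gets one class out of (2 ** output_channel_bits) ** channels possibilities, which is the label that predicted_patches[mask] is scored against with cross entropy in the new code.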