import math

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange, repeat, reduce

# helpers

def exists(val):
    return val is not None

def prob_mask_like(t, prob):
    # boolean mask with the same (batch, seq) shape as `t`, True independently with probability `prob`
    batch, seq_length, _ = t.shape
    return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob

def get_mask_subset_with_prob(patched_input, prob):
    # select exactly ceil(prob * seq_len) patch positions per example, uniformly at random
    batch, seq_len, _, device = *patched_input.shape, patched_input.device
    max_masked = math.ceil(prob * seq_len)

    rand = torch.rand((batch, seq_len), device=device)
    _, sampled_indices = rand.topk(max_masked, dim=-1)

    new_mask = torch.zeros((batch, seq_len), device=device)
    new_mask.scatter_(1, sampled_indices, 1)
    return new_mask.bool()
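
# illustrative example of the helper above (hypothetical shapes, not exercised here):
#   patches = torch.randn(2, 64, 3072)                # (batch, num_patches, patch_dim)
#   mask = get_mask_subset_with_prob(patches, 0.15)   # shape (2, 64), dtype bool
#   mask.sum(dim = -1)                                # tensor([10, 10]), since ceil(0.15 * 64) = 10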

# mpp loss

class MPPLoss(nn.Module):
    def __init__(
        self,
        patch_size,
        channels,
        output_channel_bits,
        max_pixel_val,
        mean,
        std
    ):
        super().__init__()
        self.patch_size = patch_size
        self.channels = channels
        self.output_channel_bits = output_channel_bits
        self.max_pixel_val = max_pixel_val

        self.mean = torch.tensor(mean).view(-1, 1, 1) if mean else None
        self.std = torch.tensor(std).view(-1, 1, 1) if std else None

    def forward(self, predicted_patches, target, mask):
        p, c, mpv, bits, device = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits, target.device
        bin_size = mpv / (2 ** bits)

        # un-normalize input
        if exists(self.mean) and exists(self.std):
            target = target * self.std + self.mean

        # reduce target image to per-patch, per-channel mean values
        target = target.clamp(max = mpv)  # clamp just in case
        avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1 = p, p2 = p).contiguous()

        # discretize each channel mean into 2 ** bits bins
        channel_bins = torch.arange(bin_size, mpv, bin_size, device = device)
        discretized_target = torch.bucketize(avg_target, channel_bins)

        # combine the per-channel bin indices into a single class label per patch
        bin_mask = (2 ** bits) ** torch.arange(0, c, device = device).long()
        bin_mask = rearrange(bin_mask, 'c -> () () c')

        target_label = torch.sum(bin_mask * discretized_target, dim = -1)

        # cross entropy over the masked patches only
        loss = F.cross_entropy(predicted_patches[mask], target_label[mask])
        return loss
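
# worked example of the classification target built above (with the defaults used by
# MPP below: output_channel_bits = 3, channels = 3, max_pixel_val = 1.0):
#   - each patch is reduced to its per-channel mean, one value per channel
#   - each channel mean is bucketized into 2 ** 3 = 8 bins of width 1.0 / 8
#   - the three bin indices are combined as base-8 digits (bin_mask = [1, 8, 64]),
#     giving a single label in [0, 511] that matches the
#     2 ** (output_channel_bits * channels) = 512-way output head in MPP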

# main class

class MPP(nn.Module):
    def __init__(
        self,
        transformer,
        patch_size,
        dim,
        output_channel_bits=3,
        channels=3,
        max_pixel_val=1.0,
        mask_prob=0.15,
        replace_prob=0.5,
        random_patch_prob=0.5,
        mean=None,
        std=None
    ):
        super().__init__()
        self.transformer = transformer
        self.loss = MPPLoss(patch_size, channels, output_channel_bits,
                            max_pixel_val, mean, std)

        # output transformation
        self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))

        # vit related dimensions
        self.patch_size = patch_size

        # mpp related probabilities
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob
        self.random_patch_prob = random_patch_prob

        # learned [mask] token, in raw patch (pixel) space
        self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))
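
    # forward pass, in outline:
    #   1. patchify the image and pick ceil(mask_prob * num_patches) patches per example
    #   2. corrupt the selected patches: swap some with other random patches from the same
    #      image, replace some with the learned mask token, leave the rest unchanged
    #   3. embed the corrupted patches, prepend the cls token, add positional embeddings
    #   4. run the transformer and project each patch to the discretized-color classes
    #   5. MPPLoss computes cross entropy against the targets at the masked positions only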

    def forward(self, input, **kwargs):
        transformer = self.transformer
        # clone original image for loss
        img = input.clone().detach()

        # reshape raw image to patches
        p = self.patch_size
        input = rearrange(input,
                          'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                          p1=p,
                          p2=p)

        mask = get_mask_subset_with_prob(input, self.mask_prob)

        # copy of the patch sequence that will be corrupted at the masked positions
        masked_input = input.clone().detach()

        # if random patch probability > 0 for mpp
        if self.random_patch_prob > 0:
            # rescale so that, after the mask-token replacement below overwrites a
            # `replace_prob` fraction of masked patches, roughly `random_patch_prob`
            # of them end up swapped with another random patch
            random_patch_sampling_prob = self.random_patch_prob / (
                1 - self.replace_prob)
            random_patch_prob = prob_mask_like(input,
                                               random_patch_sampling_prob).to(mask.device)

            bool_random_patch_prob = mask * (random_patch_prob == True)
            random_patches = torch.randint(0,
                                           input.shape[1],
                                           (input.shape[0], input.shape[1]),
                                           device=input.device)
            randomized_input = masked_input[
                torch.arange(masked_input.shape[0]).unsqueeze(-1),
                random_patches]
            masked_input[bool_random_patch_prob] = randomized_input[
                bool_random_patch_prob]

        # replace masked patches with the [mask] token with probability `replace_prob`
        replace_prob = prob_mask_like(input, self.replace_prob).to(mask.device)
        bool_mask_replace = (mask * replace_prob) == True
        masked_input[bool_mask_replace] = self.mask_token

        # linear embedding of patches
        masked_input = transformer.to_patch_embedding[-1](masked_input)

        # add cls token to input sequence
        b, n, _ = masked_input.shape
        cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b)
        masked_input = torch.cat((cls_tokens, masked_input), dim=1)

        # add positional embeddings to input
        masked_input += transformer.pos_embedding[:, :(n + 1)]
        masked_input = transformer.dropout(masked_input)

        # get generator output and get mpp loss
        masked_input = transformer.transformer(masked_input, **kwargs)
        cls_logits = self.to_bits(masked_input)
        logits = cls_logits[:, 1:, :]

        mpp_loss = self.loss(logits, img, mask)

        return mpp_loss
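

if __name__ == '__main__':
    # minimal self-supervised pretraining sketch, mirroring the masked patch prediction
    # example in the repository README; guarded so it does not run on import. it assumes
    # a ViT from vit_pytorch whose `to_patch_embedding` ends with the patch-to-dim Linear
    # layer that this module indexes with [-1]; the hyperparameters are illustrative only.
    from vit_pytorch import ViT

    model = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    mpp_trainer = MPP(
        transformer = model,
        patch_size = 32,
        dim = 1024,
        mask_prob = 0.15,          # probability a patch participates in the prediction task
        random_patch_prob = 0.30,  # probability a masked patch is swapped with another patch
        replace_prob = 0.50        # probability a masked patch is replaced with the mask token
    )

    opt = torch.optim.Adam(mpp_trainer.parameters(), lr = 3e-4)

    images = torch.randn(8, 3, 256, 256)  # stand-in batch of unlabelled images

    loss = mpp_trainer(images)
    opt.zero_grad()
    loss.backward()
    opt.step()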