mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 16:12:29 +00:00
167 lines
5.8 KiB
Python
167 lines
5.8 KiB
Python
import math
|
|
from functools import reduce
|
|
|
|
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
|
|
from einops import rearrange, repeat
|
|
|
|
# helpers
|
|
|
|
|
|
def prob_mask_like(t, prob):
    """Sample an independent boolean mask with one entry per sequence position.

    Args:
        t: Tensor with exactly three dimensions. Only the sizes of its first
            two dimensions (batch, sequence length) and its device are used.
        prob: Probability that any given entry is True. Samples are drawn
            from [0, 1), so ``prob <= 0`` yields an all-False mask and
            ``prob >= 1`` an all-True mask.

    Returns:
        Boolean tensor of shape ``(batch, seq_length)`` on the same device
        as ``t``.
    """
    batch, seq_length, _ = t.shape
    # Sample directly on t's device: a CPU-resident mask cannot be combined
    # with masks or inputs that live on an accelerator.
    uniform = torch.rand((batch, seq_length),
                         dtype=torch.float,
                         device=t.device)
    return uniform < prob
|
|
|
|
|
|
def get_mask_subset_with_prob(patched_input, prob):
    """Select a fixed-size random subset of positions in every sequence.

    Args:
        patched_input: Tensor with exactly three dimensions; only the sizes
            of its first two dimensions (batch, sequence length) and its
            device are used.
        prob: Fraction of positions to select. Each row receives exactly
            ``ceil(prob * seq_len)`` selected positions.

    Returns:
        Boolean tensor of shape ``(batch, seq_len)`` on the input's device,
        True at the selected positions.
    """
    batch, seq_len, _ = patched_input.shape
    device = patched_input.device
    num_selected = math.ceil(prob * seq_len)

    # Give every position a random score and keep the indices of the
    # highest-scoring ones in each row.
    scores = torch.rand((batch, seq_len), device=device)
    chosen = scores.topk(num_selected, dim=-1).indices

    # Mark the chosen indices with ones, then convert to a boolean mask.
    selection = torch.zeros((batch, seq_len), device=device)
    selection.scatter_(1, chosen, 1)
    return selection.bool()
|
|
|
|
|
|
# mpp loss
|
|
|
|
|
|
class MPPLoss(nn.Module):
    """Cross-entropy loss for masked patch prediction.

    For every image patch the per-channel mean pixel value is bucketed into
    ``output_channel_bits`` bins. The per-channel bin indices are one-hot
    encoded, concatenated into ``channels * output_channel_bits`` bits and
    read as a binary number (most significant bit first), which gives one
    integer class label per patch. The loss is the cross entropy between the
    predicted logits and those labels, restricted to the masked patches.

    Args:
        patch_size: Side length of the square patches.
        channels: Number of image channels.
        output_channel_bits: Number of bins each channel mean is bucketed
            into.
        max_pixel_val: Upper end of the pixel value range used to derive the
            bin width.
    """

    def __init__(self, patch_size, channels, output_channel_bits,
                 max_pixel_val):
        super().__init__()
        self.patch_size = patch_size
        self.channels = channels
        self.output_channel_bits = output_channel_bits
        self.max_pixel_val = max_pixel_val

    def forward(self, predicted_patches, target, mask):
        """Compute the loss over the masked patches.

        Args:
            predicted_patches: Logits indexed as ``[batch, patch, class]``;
                expected to provide one logit per possible label, i.e.
                ``2 ** (channels * output_channel_bits)`` classes.
            target: Image batch laid out as ``(b, c, H, W)`` with ``H`` and
                ``W`` divisible by ``patch_size``; must be floating point
                because patch means are taken.
            mask: Boolean tensor over ``[batch, patch]`` selecting the
                patches that contribute to the loss.

        Returns:
            Scalar cross-entropy loss tensor.
        """
        p = self.patch_size
        c, bi = self.channels, self.output_channel_bits

        # reshape target to patches: (b, c, H, W) -> (b, patches, c, pixels)
        target = rearrange(target,
                           "b c (h p1) (w p2) -> b (h w) c (p1 p2)",
                           p1=p,
                           p2=p)

        # mean pixel value of every channel of every patch
        avg_target = target.mean(dim=3)

        # Bin edges must sit on the same device (and share the dtype) as the
        # values being bucketized; a default CPU tensor breaks as soon as
        # the images are on an accelerator.
        bin_size = self.max_pixel_val / bi
        channel_bins = torch.arange(bin_size,
                                    self.max_pixel_val,
                                    bin_size,
                                    device=avg_target.device,
                                    dtype=avg_target.dtype)
        discretized_target = torch.bucketize(avg_target, channel_bins)

        # one-hot encode each channel's bin, then lay all channels side by
        # side as one bit string per patch
        discretized_target = F.one_hot(discretized_target, bi)
        discretized_target = rearrange(discretized_target,
                                       "b n c bi -> b n (c bi)",
                                       c=c,
                                       bi=bi)

        # powers of two, most significant bit first, to turn the bit string
        # into a single integer label
        bin_mask = 2**torch.arange(c * bi - 1,
                                   -1,
                                   -1,
                                   device=discretized_target.device,
                                   dtype=discretized_target.dtype)
        target_label = torch.sum(bin_mask * discretized_target, -1)

        # only the masked patches are scored
        predicted_patches = predicted_patches[mask]
        target_label = target_label[mask]
        loss = F.cross_entropy(predicted_patches, target_label)
        return loss
|
|
|
|
|
|
# main class
|
|
|
|
|
|
class MPP(nn.Module):
    """Masked patch prediction (MPP) pre-training wrapper for a ViT-style model.

    A random subset of image patches is corrupted, either swapped for another
    patch of the same image or overwritten with a learned mask token. The
    wrapped model is then trained to predict the discretized mean colour of
    the corrupted patches (see ``MPPLoss``).

    The wrapped ``transformer`` must expose ``to_patch_embedding`` (an
    indexable container whose last element is applied to the flattened
    patches), ``cls_token``, ``pos_embedding``, ``dropout`` and
    ``transformer``.

    Args:
        transformer: The model to pre-train.
        patch_size: Side length of the square patches.
        dim: Width of the transformer output fed to the prediction head.
        output_channel_bits: Number of bins per channel for the target.
        channels: Number of image channels.
        max_pixel_val: Upper end of the pixel value range.
        mask_prob: Fraction of patches selected for prediction.
        replace_prob: Probability that a selected patch is overwritten with
            the mask token.
        random_patch_prob: Target probability that a selected patch ends up
            replaced by a random patch of the same image.
    """

    def __init__(self,
                 transformer,
                 patch_size,
                 dim,
                 output_channel_bits=3,
                 channels=3,
                 max_pixel_val=1.0,
                 mask_prob=0.15,
                 replace_prob=0.5,
                 random_patch_prob=0.5):
        super().__init__()

        self.transformer = transformer
        self.loss = MPPLoss(patch_size, channels, output_channel_bits,
                            max_pixel_val)

        # output transformation: one logit per possible discretized label
        self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))

        # vit related dimensions
        self.patch_size = patch_size

        # mpp related probabilities
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob
        self.random_patch_prob = random_patch_prob

        # token ids
        # The mask token overwrites raw flattened patches (before the patch
        # embedding is applied), so it needs the raw patch width
        # patch_size**2 * channels. Sizing it from `dim` only works when
        # dim happens to equal patch_size**2.
        self.mask_token = nn.Parameter(
            torch.randn(1, 1, channels * patch_size**2))

    def forward(self, input, **kwargs):
        """Corrupt a batch of images and return the MPP loss.

        Args:
            input: Image batch laid out as ``(b, c, H, W)`` with ``H`` and
                ``W`` divisible by ``patch_size``.
            **kwargs: Forwarded to ``transformer.transformer``.

        Returns:
            Scalar loss tensor produced by ``MPPLoss``.
        """
        transformer = self.transformer

        # clone original image for loss
        img = input.clone().detach()

        # reshape raw image to patches
        p = self.patch_size
        input = rearrange(input,
                          'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                          p1=p,
                          p2=p)
        device = input.device

        mask = get_mask_subset_with_prob(input, self.mask_prob)

        # mask input with mask patches with probability of `replace_prob`
        # (keep patches the same with probability 1 - replace_prob)
        masked_input = input.clone().detach()

        # if random token probability > 0 for mpp
        # With replace_prob >= 1 every selected patch is overwritten by the
        # mask token below, so random patches could never survive; skipping
        # the step also avoids dividing by zero.
        if self.random_patch_prob > 0 and self.replace_prob < 1:
            # Random patches are only kept where the mask token is not
            # applied afterwards, so the sampling rate is scaled up by
            # 1 / (1 - replace_prob).
            random_patch_sampling_prob = self.random_patch_prob / (
                1 - self.replace_prob)
            random_patch_prob = prob_mask_like(
                input, random_patch_sampling_prob).to(mask.device)
            bool_random_patch_prob = mask & random_patch_prob

            # for every position, draw a source patch index from the same
            # image
            random_patches = torch.randint(0,
                                           input.shape[1],
                                           (input.shape[0], input.shape[1]),
                                           device=device)
            batch_index = torch.arange(masked_input.shape[0],
                                       device=device).unsqueeze(-1)
            randomized_input = masked_input[batch_index, random_patches]
            masked_input[bool_random_patch_prob] = randomized_input[
                bool_random_patch_prob]

        # [mask] input
        replace_prob = prob_mask_like(input,
                                      self.replace_prob).to(mask.device)
        bool_mask_replace = mask & replace_prob
        masked_input[bool_mask_replace] = self.mask_token

        # linear embedding of patches
        # NOTE(review): assumes to_patch_embedding[-1] is the projection
        # that takes flattened patches; confirm against the wrapped model.
        masked_input = transformer.to_patch_embedding[-1](masked_input)

        # add cls token to input sequence
        b, n, _ = masked_input.shape
        cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b)
        masked_input = torch.cat((cls_tokens, masked_input), dim=1)

        # add positional embeddings to input
        masked_input += transformer.pos_embedding[:, :(n + 1)]
        masked_input = transformer.dropout(masked_input)

        # get generator output and get mpp loss
        masked_input = transformer.transformer(masked_input, **kwargs)
        cls_logits = self.to_bits(masked_input)
        # drop the cls position so logits line up with the patches
        logits = cls_logits[:, 1:, :]

        mpp_loss = self.loss(logits, img, mask)

        return mpp_loss
|