Working implementation of masked patch prediction as a wrapper. The code still needs cleanup.

Zack Ankner
2021-02-09 22:55:06 -07:00
parent 174e71cf53
commit a0a4fa5e7d
6 changed files with 69 additions and 95 deletions

test.py

@@ -1,22 +1,22 @@
import torch
from vit_pytorch import MPP
from vit_pytorch import MPP, ViT
# from vit_pytorch import ViT, MaskedPredictionLoss
# v = ViT(
# image_size = 256,
# patch_size = 32,
# num_classes = 1000,
# dim = 1024,
# depth = 6,
# heads = 16,
# mlp_dim = 2048,
# dropout = 0.1,
# emb_dropout = 0.1
# )
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 3,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
# l = MaskedPredictionLoss(patch_size=32, img_size=256)
# img = torch.randn(1, 3, 256, 256)
img = torch.randn(2, 3, 256, 256)
# mask = [1,2,3,4] # optional mask, designating which patch to attend to
# preds = v(img) # (1, 1000)
@@ -24,19 +24,19 @@ from vit_pytorch import MPP
# print(preds.shape)
transformer = 5
trainer = MPP(
transformer,
5,
75,
transformer = v,
patch_size = 32,
dim = 1024,
mask_prob = 0.15, # masking probability for masked patch prediction
random_patch_prob=0.05,
replace_prob = 0.90, # ~10% probability that token will not be masked, but included in loss, as detailed in the paper
random_patch_prob=0.30,
replace_prob = 0.50, # ~10% probability that token will not be masked, but included in loss, as detailed in the paper
)
data = torch.rand((2, 3, 10, 10))
# data = torch.rand((2, 3, 10, 10))
loss = trainer(data)
loss = trainer(img)
print(loss)
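
The updated test.py drives the wrapper end to end on random images. Note that num_classes is set to 3, matching the number of colour channels, so the ViT head emits one value per channel per token, which is the target shape MPPLoss regresses against. Below is a minimal sketch of how the same pieces would sit in an actual pre-training loop; the optimizer, learning rate, and loop are illustrative assumptions, not part of this commit.

import torch
from vit_pytorch import MPP, ViT

# only ViT and MPP come from this repo; the data and optimizer are stand-ins
v = ViT(image_size = 256, patch_size = 32, num_classes = 3, dim = 1024,
        depth = 6, heads = 16, mlp_dim = 2048, dropout = 0.1, emb_dropout = 0.1)

trainer = MPP(transformer = v, patch_size = 32, dim = 1024,
              mask_prob = 0.15, random_patch_prob = 0.30, replace_prob = 0.50)

opt = torch.optim.Adam(trainer.parameters(), lr = 3e-4)

for _ in range(100):                        # stand-in for a real data loader
    images = torch.rand(2, 3, 256, 256)     # values in [0, 1), matching the loss bins
    loss = trainer(images)                  # masked patch prediction loss
    opt.zero_grad()
    loss.backward()
    opt.step()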

vit_pytorch/__init__.py

@@ -1,3 +1,3 @@
from vit_pytorch.vit_pytorch import ViT
from vit_pytorch.masked_prediction_loss import MaskedPredictionLoss
from vit_pytorch.mpp_loss import MPPLoss
from vit_pytorch.mpp_pytorch import MPP

vit_pytorch/masked_prediction_loss.py (deleted)

@@ -1,55 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class MaskedPredictionLoss(nn.Module):
def __init__(self,
patch_size,
img_size,
device="cpu"):
super(MaskedPredictionLoss, self).__init__()
self.patch_size = patch_size
self.num_patch_axis = img_size // patch_size
def _transform_targets(self, targets, masked_patches):
masked_patches_dims = []
for masked_patch in masked_patches:
height_offset = self.patch_size * ((masked_patch) // self.num_patch_axis)
width_offset = self.patch_size * ((masked_patch) % self.num_patch_axis)
masked_patches_dims.append([height_offset, width_offset])
target_patches = []
for target in range(targets.shape[0]):
for masked_patch in masked_patches_dims:
height_offset, width_offset = masked_patch
extracted_patch = targets[target, :,
height_offset:height_offset + self.patch_size,
width_offset:width_offset + self.patch_size]
target_patches.append(extracted_patch)
target_patches_tensor = torch.stack(target_patches)
target_patches_tensor = (target_patches_tensor // 0.34).long()
encoded_targets = F.one_hot(target_patches_tensor, 3)
n, c, w, h, e = encoded_targets.shape
encoded_targets = torch.reshape(encoded_targets, [n, w, h, c * e])
mean_targets = torch.mean(encoded_targets.float(),
dim=[1, 2],
keepdim=True).view(n, c * e)
return mean_targets
def _transform_outputs(self, outputs, masked_patches):
output_dim = outputs.shape[-1]
masked_patches_shifted = [
masked_patch + 1 for masked_patch in masked_patches
]
outputs = outputs[:, masked_patches_shifted, :]
outputs = outputs.view(-1, output_dim)
return outputs
def forward(self, x, y, masked_patches):
transformed_outputs = self._transform_outputs(x, masked_patches)
transformed_targets = self._transform_targets(y, masked_patches)
loss = F.mse_loss(transformed_outputs, transformed_targets)
return loss

vit_pytorch/mpp_loss.py (new file)

@@ -0,0 +1,26 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class MPPLoss(nn.Module):
def __init__(self, patch_size):
super(MPPLoss, self).__init__()
self.patch_size = patch_size
def forward(self, predicted_patches, target, mask):
# reshape target to patches
p = self.patch_size
target = rearrange(target, "b c (h p1) (w p2) -> b (h w) c (p1 p2) ", p1 = p, p2 = p)
channel_bins = torch.tensor([0.333, 0.666, 1.0])
target = torch.bucketize(target, channel_bins, right=True)
target = target.float().mean(dim=3)
predicted_patches = predicted_patches[mask]
target = target[mask]
loss = F.mse_loss(predicted_patches, target)
return loss
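
For intuition on the new loss: MPPLoss flattens the image into per-patch pixel vectors, bucketizes every channel value into discrete bins bounded by 0.333 / 0.666 / 1.0, and averages the bin indices over the pixels of each patch, giving one regression target per channel per patch; only the masked patches are then scored with MSE against the transformer's predictions. A small shape trace of that target construction (illustrative, not part of the commit):

import torch
from einops import rearrange

# tiny example: 1 image, 3 channels, 4x4 pixels, 2x2 patches
p = 2
img = torch.rand(1, 3, 4, 4)                                   # pixel values in [0, 1)
target = rearrange(img, "b c (h p1) (w p2) -> b (h w) c (p1 p2)", p1 = p, p2 = p)
bins = torch.tensor([0.333, 0.666, 1.0])
target = torch.bucketize(target, bins, right = True)           # bin index per pixel and channel
target = target.float().mean(dim = 3)                          # mean bin index per patch and channel
print(target.shape)                                            # torch.Size([1, 4, 3])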

vit_pytorch/mpp_pytorch.py

@@ -7,6 +7,8 @@ import torch.nn.functional as F
from einops import rearrange
from vit_pytorch import MPPLoss
# helpers
def prob_mask_like(t, prob):
@@ -32,12 +34,14 @@ class MPP(nn.Module):
transformer,
patch_size,
dim,
channels = 3,
mask_prob = 0.15,
replace_prob = 0.5,
random_patch_prob = 0.5):
super().__init__()
self.transformer = transformer
self.loss = MPPLoss(patch_size)
# vit related dimensions
self.patch_size = patch_size
@@ -48,9 +52,11 @@ class MPP(nn.Module):
self.random_patch_prob = random_patch_prob
# token ids
self.mask_token = nn.Parameter(torch.randn(1, 1, dim))
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
def forward(self, input, **kwargs):
# clone original image for loss
img = input.clone().detach()
# reshape raw image to patches
p = self.patch_size
@@ -75,17 +81,15 @@ class MPP(nn.Module):
bool_mask_replace = (mask * replace_prob) == True
masked_input[bool_mask_replace] = self.mask_token
# set inverse of mask to padding tokens for labels
# get labels for input patches that were masked
bool_mask = mask == True
labels = input[bool_mask]
# get generator output and get mpp loss
logits = self.transformer(masked_input, **kwargs)
cls_logits = self.transformer(masked_input, mpp=True, **kwargs)
logits = cls_logits[:,1:,:]
mpp_loss = F.cross_entropy(
logits.transpose(1, 2),
labels,
)
mpp_loss = self.loss(logits, img, mask)
return mpp_loss
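
On the masking scheme in MPP.forward: mask_prob selects the patches that enter the loss, and replace_prob then decides which of those actually have their pixel content swapped for the learned mask token; patches that are selected but not replaced keep their pixels yet are still predicted and scored. Also note the mask token is sized dim * channels, which lines up with the flattened patch length (patch_size**2 * channels) in the test.py configuration only because dim equals patch_size**2 there. A hedged sketch of how the boolean masks combine, assuming prob_mask_like draws one independent Bernoulli(prob) per patch (its body is not shown in this hunk):

import torch

def prob_mask_like(t, prob):
    # assumed helper behaviour: one Bernoulli(prob) draw per patch
    batch, num_patches = t.shape[:2]
    return torch.zeros(batch, num_patches).uniform_(0, 1) < prob

patches = torch.randn(2, 64, 32 * 32 * 3)     # (batch, patches, flattened pixels)
mask = prob_mask_like(patches, 0.15)          # patches that enter the loss
replace = prob_mask_like(patches, 0.50)       # of those, which get the mask token
bool_mask_replace = mask & replace            # content hidden behind the mask token
# patches in `mask` but not in `bool_mask_replace` keep their pixels,
# yet their positions are still scored by MPPLoss via `mask`.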

vit_pytorch/vit_pytorch.py

@@ -97,7 +97,6 @@ class ViT(nn.Module):
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.mask_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
@@ -110,25 +109,25 @@ class ViT(nn.Module):
nn.Linear(dim, num_classes)
)
def forward(self, img, mask = None, prediction_mask = None):
p = self.patch_size
def forward(self, img, mask = None, mpp = False):
if mpp:
x = img
else:
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
if prediction_mask:
x[:, prediction_mask, :] = self.mask_token
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x, mask)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
if not mpp:
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
return self.mlp_head(x)
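
The new mpp flag changes both what ViT.forward expects and what it returns: with mpp = True the input is already a batch of flattened patch vectors (with mask tokens substituted by the wrapper), so the image-to-patch rearrange is skipped, and the CLS/mean pooling is skipped as well, so the head returns one prediction per token; MPP then drops the CLS position with [:, 1:, :] before handing the logits to MPPLoss. A shape trace under the test.py hyper-parameters (illustrative, not part of the commit):

import torch
from vit_pytorch import ViT

v = ViT(image_size = 256, patch_size = 32, num_classes = 3, dim = 1024,
        depth = 6, heads = 16, mlp_dim = 2048, dropout = 0.1, emb_dropout = 0.1)

patches = torch.rand(2, 64, 32 * 32 * 3)   # (batch, 8x8 patches, flattened pixels)
out = v(patches, mpp = True)               # pooling skipped: one output per token
print(out.shape)                           # expected: torch.Size([2, 65, 3]), CLS slot included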