mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Working implementation of masked patch prediction as a wrapper. The code still needs to be cleaned up.
This commit is contained in:
44
test.py
44
test.py
@@ -1,22 +1,22 @@
|
||||
import torch
|
||||
from vit_pytorch import MPP
|
||||
from vit_pytorch import MPP, ViT
|
||||
# from vit_pytorch import ViT, MaskedPredictionLoss
|
||||
|
||||
# v = ViT(
|
||||
# image_size = 256,
|
||||
# patch_size = 32,
|
||||
# num_classes = 1000,
|
||||
# dim = 1024,
|
||||
# depth = 6,
|
||||
# heads = 16,
|
||||
# mlp_dim = 2048,
|
||||
# dropout = 0.1,
|
||||
# emb_dropout = 0.1
|
||||
# )
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 3,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
# l = MaskedPredictionLoss(patch_size=32, img_size=256)
|
||||
|
||||
# img = torch.randn(1, 3, 256, 256)
|
||||
img = torch.randn(2, 3, 256, 256)
|
||||
# mask = [1,2,3,4] # optional mask, designating which patch to attend to
|
||||
|
||||
# preds = v(img) # (1, 1000)
|
||||
@@ -24,19 +24,19 @@ from vit_pytorch import MPP
|
||||
|
||||
# print(preds.shape)
|
||||
|
||||
transformer = 5
|
||||
|
||||
|
||||
trainer = MPP(
|
||||
transformer,
|
||||
5,
|
||||
75,
|
||||
transformer = v,
|
||||
patch_size = 32,
|
||||
dim = 1024,
|
||||
mask_prob = 0.15, # masking probability for masked language modeling
|
||||
random_patch_prob=0.05,
|
||||
replace_prob = 0.90, # ~10% probability that token will not be masked, but included in loss, as detailed in the paper
|
||||
random_patch_prob=0.30,
|
||||
replace_prob = 0.50, # ~10% probability that token will not be masked, but included in loss, as detailed in the paper
|
||||
)
|
||||
|
||||
data = torch.rand((2, 3, 10, 10))
|
||||
# data = torch.rand((2, 3, 10, 10))
|
||||
|
||||
loss = trainer(data)
|
||||
loss = trainer(img)
|
||||
|
||||
print(loss)
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from vit_pytorch.vit_pytorch import ViT
|
||||
from vit_pytorch.masked_prediction_loss import MaskedPredictionLoss
|
||||
from vit_pytorch.mpp_loss import MPPLoss
|
||||
from vit_pytorch.mpp_pytorch import MPP
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class MaskedPredictionLoss(nn.Module):
    """MSE loss between model outputs and colour-bin statistics of masked patches.

    For every masked patch the target is, per channel, the fraction of pixels
    that fall into each of three intensity bins (bin index = value // 0.34).
    The model output at the corresponding token position is regressed onto
    that ``channels * 3`` vector with mean squared error.

    Args:
        patch_size: Side length, in pixels, of one square patch.
        img_size: Side length, in pixels, of the square image; only used to
            derive the number of patches along one axis.
        device: Accepted for backward compatibility; it is not used.
    """

    def __init__(self,
                 patch_size,
                 img_size,
                 device="cpu"):
        super().__init__()
        self.patch_size = patch_size
        self.num_patch_axis = img_size // patch_size

    def _transform_targets(self, targets, masked_patches):
        """Build the per-patch bin-frequency targets.

        Args:
            targets: Image batch indexed as (batch, channel, height, width).
            masked_patches: Iterable of row-major patch indices to extract.

        Returns:
            Float tensor of shape (batch * len(masked_patches), channels * 3),
            ordered batch-major, then by position in ``masked_patches``. For
            each channel the three entries are the fractions of the patch's
            pixels in bins 0, 1 and 2.

        Note:
            Pixel values outside roughly [0, 1.02) yield a bin index that
            ``F.one_hot(..., 3)`` rejects.
        """
        size = self.patch_size

        # Convert each row-major patch index to its top-left pixel offset.
        offsets = [
            (size * (index // self.num_patch_axis),
             size * (index % self.num_patch_axis))
            for index in masked_patches
        ]

        patches = [
            targets[sample, :, top:top + size, left:left + size]
            for sample in range(targets.shape[0])
            for top, left in offsets
        ]

        stacked = torch.stack(patches)                 # (n, c, size, size)
        bin_index = (stacked // 0.34).long()
        encoded = F.one_hot(bin_index, 3)              # (n, c, size, size, 3)
        n, c, _, _, e = encoded.shape

        # Average over the two spatial axes directly. Reshaping
        # (n, c, w, h, e) into (n, w, h, c * e) without a permute would
        # reinterpret memory and blend values from different channels.
        bin_frequencies = encoded.float().mean(dim=[2, 3])  # (n, c, e)
        return bin_frequencies.reshape(n, c * e)

    def _transform_outputs(self, outputs, masked_patches):
        """Select the output vectors that belong to the masked patches.

        Args:
            outputs: Tensor indexed as (batch, tokens, features).
            masked_patches: Iterable of patch indices.

        Returns:
            Tensor of shape (batch * len(masked_patches), features).
        """
        output_dim = outputs.shape[-1]

        # Patch i is read from token position i + 1 (presumably position 0
        # holds a class token -- confirm against the model).
        token_positions = [index + 1 for index in masked_patches]
        selected = outputs[:, token_positions, :]
        return selected.view(-1, output_dim)

    def forward(self, x, y, masked_patches):
        """Return the MSE between selected outputs ``x`` and targets from ``y``.

        Args:
            x: Model outputs, indexed as (batch, tokens, features).
            y: Image batch, indexed as (batch, channel, height, width).
            masked_patches: Iterable of row-major patch indices.
        """
        transformed_outputs = self._transform_outputs(x, masked_patches)
        transformed_targets = self._transform_targets(y, masked_patches)
        return F.mse_loss(transformed_outputs, transformed_targets)
|
||||
26
vit_pytorch/mpp_loss.py
Normal file
26
vit_pytorch/mpp_loss.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
|
||||
class MPPLoss(nn.Module):
    """Masked-patch-prediction loss on per-channel mean intensity bins.

    The target image is cut into square patches; every pixel is mapped to a
    bin index with ``torch.bucketize`` and the indices are averaged per
    channel within each patch. The loss is the MSE between those averages
    and the predictions at the masked positions.

    Args:
        patch_size: Side length, in pixels, of one square patch.
    """

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, predicted_patches, target, mask):
        """Compute the loss over the positions selected by ``mask``.

        Args:
            predicted_patches: Per-patch predictions; indexed with ``mask``.
                After masking, its shape must be compatible with the masked
                target for ``F.mse_loss``.
            target: Image batch laid out as (batch, channel, height, width),
                with height and width divisible by ``patch_size``.
            mask: Index applied to the leading dimensions of both tensors
                (presumably a boolean (batch, num_patches) mask -- confirm
                against the caller).

        Returns:
            Scalar MSE loss tensor.
        """
        # reshape target to patches, keeping channels apart and flattening
        # the pixels of each patch into the last axis
        p = self.patch_size
        target = rearrange(target, "b c (h p1) (w p2) -> b (h w) c (p1 p2) ", p1 = p, p2 = p)

        # Create the boundaries on the target's device; a CPU-only tensor
        # would make bucketize fail for inputs on an accelerator.
        channel_bins = torch.tensor([0.333, 0.666, 1.0], device=target.device)

        # With right=True a value v gets index i where
        # bins[i - 1] <= v < bins[i], so values >= 1.0 land in index 3.
        target = torch.bucketize(target, channel_bins, right=True)
        target = target.float().mean(dim=3)

        predicted_patches = predicted_patches[mask]
        target = target[mask]

        return F.mse_loss(predicted_patches, target)
|
||||
@@ -7,6 +7,8 @@ import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from vit_pytorch import MPPLoss
|
||||
|
||||
# helpers
|
||||
|
||||
def prob_mask_like(t, prob):
|
||||
@@ -32,12 +34,14 @@ class MPP(nn.Module):
|
||||
transformer,
|
||||
patch_size,
|
||||
dim,
|
||||
channels = 3,
|
||||
mask_prob = 0.15,
|
||||
replace_prob = 0.5,
|
||||
random_patch_prob = 0.5):
|
||||
super().__init__()
|
||||
|
||||
self.transformer = transformer
|
||||
self.loss = MPPLoss(patch_size)
|
||||
|
||||
# vit related dimensions
|
||||
self.patch_size = patch_size
|
||||
@@ -48,9 +52,11 @@ class MPP(nn.Module):
|
||||
self.random_patch_prob = random_patch_prob
|
||||
|
||||
# token ids
|
||||
self.mask_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
|
||||
|
||||
def forward(self, input, **kwargs):
|
||||
# clone original image for loss
|
||||
img = input.clone().detach()
|
||||
|
||||
# reshape raw image to patches
|
||||
p = self.patch_size
|
||||
@@ -75,17 +81,15 @@ class MPP(nn.Module):
|
||||
bool_mask_replace = (mask * replace_prob) == True
|
||||
masked_input[bool_mask_replace] = self.mask_token
|
||||
|
||||
# set inverse of mask to padding tokens for labels
|
||||
# get labels for input patches that were masked
|
||||
bool_mask = mask == True
|
||||
labels = input[bool_mask]
|
||||
|
||||
# get generator output and get mpp loss
|
||||
logits = self.transformer(masked_input, **kwargs)
|
||||
cls_logits = self.transformer(masked_input, mpp=True, **kwargs)
|
||||
logits = cls_logits[:,1:,:]
|
||||
|
||||
mpp_loss = F.cross_entropy(
|
||||
logits.transpose(1, 2),
|
||||
labels,
|
||||
)
|
||||
mpp_loss = self.loss(logits, img, mask)
|
||||
|
||||
return mpp_loss
|
||||
|
||||
|
||||
@@ -97,7 +97,6 @@ class ViT(nn.Module):
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.patch_to_embedding = nn.Linear(patch_dim, dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.mask_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
@@ -110,25 +109,25 @@ class ViT(nn.Module):
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img, mask = None, prediction_mask = None):
|
||||
p = self.patch_size
|
||||
def forward(self, img, mask = None, mpp = False):
|
||||
if mpp:
|
||||
x = img
|
||||
else:
|
||||
p = self.patch_size
|
||||
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
|
||||
|
||||
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
|
||||
x = self.patch_to_embedding(x)
|
||||
b, n, _ = x.shape
|
||||
|
||||
if prediction_mask:
|
||||
x[:, prediction_mask, :] = self.mask_token
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x, mask)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
if not mpp:
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
return self.mlp_head(x)
|
||||
Reference in New Issue
Block a user