mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
fix mpp
This commit is contained in:
@@ -437,7 +437,7 @@ mpp_trainer = MPP(
|
||||
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
|
||||
|
||||
def sample_unlabelled_images():
|
||||
return torch.randn(20, 3, 256, 256)
|
||||
return torch.FloatTensor(20, 3, 256, 256).uniform_(0., 1.)
|
||||
|
||||
for _ in range(100):
|
||||
images = sample_unlabelled_images()
|
||||
|
||||
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.19.5',
|
||||
version = '0.19.6',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
import math
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops import rearrange, repeat, reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def prob_mask_like(t, prob):
|
||||
batch, seq_length, _ = t.shape
|
||||
return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob
|
||||
|
||||
|
||||
def get_mask_subset_with_prob(patched_input, prob):
|
||||
batch, seq_len, _, device = *patched_input.shape, patched_input.device
|
||||
max_masked = math.ceil(prob * seq_len)
|
||||
@@ -31,55 +31,45 @@ def get_mask_subset_with_prob(patched_input, prob):
|
||||
|
||||
|
||||
class MPPLoss(nn.Module):
|
||||
def __init__(self, patch_size, channels, output_channel_bits,
|
||||
max_pixel_val, mean, std):
|
||||
super(MPPLoss, self).__init__()
|
||||
def __init__(
|
||||
self,
|
||||
patch_size,
|
||||
channels,
|
||||
output_channel_bits,
|
||||
max_pixel_val,
|
||||
mean,
|
||||
std
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.channels = channels
|
||||
self.output_channel_bits = output_channel_bits
|
||||
self.max_pixel_val = max_pixel_val
|
||||
|
||||
if mean:
|
||||
self.mean = torch.tensor(mean).view(-1, 1, 1)
|
||||
else:
|
||||
self.mean = None
|
||||
if std:
|
||||
self.std = torch.tensor(std).view(-1, 1, 1)
|
||||
else:
|
||||
self.std = None
|
||||
self.mean = torch.tensor(mean).view(-1, 1, 1) if mean else None
|
||||
self.std = torch.tensor(std).view(-1, 1, 1) if std else None
|
||||
|
||||
def forward(self, predicted_patches, target, mask):
|
||||
p, c, mpv, bits, device = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits, target.device
|
||||
bin_size = mpv / (2 ** bits)
|
||||
|
||||
# un-normalize input
|
||||
if self.mean is not None and self.std is not None:
|
||||
if exists(self.mean) and exists(self.std):
|
||||
target = target * self.std + self.mean
|
||||
|
||||
# reshape target to patches
|
||||
p = self.patch_size
|
||||
target = rearrange(target,
|
||||
"b c (h p1) (w p2) -> b (h w) c (p1 p2) ",
|
||||
p1=p,
|
||||
p2=p)
|
||||
target = target.clamp(max = mpv) # clamp just in case
|
||||
avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1 = p, p2 = p).contiguous()
|
||||
|
||||
avg_target = target.mean(dim=3)
|
||||
|
||||
bin_size = self.max_pixel_val / self.output_channel_bits
|
||||
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size).to(avg_target.device)
|
||||
channel_bins = torch.arange(bin_size, mpv, bin_size, device = device)
|
||||
discretized_target = torch.bucketize(avg_target, channel_bins)
|
||||
discretized_target = F.one_hot(discretized_target,
|
||||
self.output_channel_bits)
|
||||
c, bi = self.channels, self.output_channel_bits
|
||||
discretized_target = rearrange(discretized_target,
|
||||
"b n c bi -> b n (c bi)",
|
||||
c=c,
|
||||
bi=bi)
|
||||
|
||||
bin_mask = 2**torch.arange(c * bi - 1, -1,
|
||||
-1).to(discretized_target.device,
|
||||
discretized_target.dtype)
|
||||
target_label = torch.sum(bin_mask * discretized_target, -1)
|
||||
predicted_patches = predicted_patches[mask]
|
||||
target_label = target_label[mask]
|
||||
loss = F.cross_entropy(predicted_patches, target_label)
|
||||
bin_mask = (2 ** bits) ** torch.arange(0, c, device = device).long()
|
||||
bin_mask = rearrange(bin_mask, 'c -> () () c')
|
||||
|
||||
target_label = torch.sum(bin_mask * discretized_target, dim = -1)
|
||||
|
||||
loss = F.cross_entropy(predicted_patches[mask], target_label[mask])
|
||||
return loss
|
||||
|
||||
|
||||
@@ -87,18 +77,20 @@ class MPPLoss(nn.Module):
|
||||
|
||||
|
||||
class MPP(nn.Module):
|
||||
def __init__(self,
|
||||
transformer,
|
||||
patch_size,
|
||||
dim,
|
||||
output_channel_bits=3,
|
||||
channels=3,
|
||||
max_pixel_val=1.0,
|
||||
mask_prob=0.15,
|
||||
replace_prob=0.5,
|
||||
random_patch_prob=0.5,
|
||||
mean=None,
|
||||
std=None):
|
||||
def __init__(
|
||||
self,
|
||||
transformer,
|
||||
patch_size,
|
||||
dim,
|
||||
output_channel_bits=3,
|
||||
channels=3,
|
||||
max_pixel_val=1.0,
|
||||
mask_prob=0.15,
|
||||
replace_prob=0.5,
|
||||
random_patch_prob=0.5,
|
||||
mean=None,
|
||||
std=None
|
||||
):
|
||||
super().__init__()
|
||||
self.transformer = transformer
|
||||
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
|
||||
|
||||
Reference in New Issue
Block a user