This commit is contained in:
Phil Wang
2021-06-16 16:46:32 -07:00
parent 53884f583f
commit 64a2ef6462
3 changed files with 44 additions and 52 deletions

View File

@@ -437,7 +437,7 @@ mpp_trainer = MPP(
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
def sample_unlabelled_images():
return torch.randn(20, 3, 256, 256)
return torch.FloatTensor(20, 3, 256, 256).uniform_(0., 1.)
for _ in range(100):
images = sample_unlabelled_images()

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.19.5',
version = '0.19.6',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -1,20 +1,20 @@
import math
from functools import reduce
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops import rearrange, repeat, reduce
# helpers
def exists(val):
return val is not None
def prob_mask_like(t, prob):
batch, seq_length, _ = t.shape
return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob
def get_mask_subset_with_prob(patched_input, prob):
batch, seq_len, _, device = *patched_input.shape, patched_input.device
max_masked = math.ceil(prob * seq_len)
@@ -31,55 +31,45 @@ def get_mask_subset_with_prob(patched_input, prob):
class MPPLoss(nn.Module):
def __init__(self, patch_size, channels, output_channel_bits,
max_pixel_val, mean, std):
super(MPPLoss, self).__init__()
def __init__(
self,
patch_size,
channels,
output_channel_bits,
max_pixel_val,
mean,
std
):
super().__init__()
self.patch_size = patch_size
self.channels = channels
self.output_channel_bits = output_channel_bits
self.max_pixel_val = max_pixel_val
if mean:
self.mean = torch.tensor(mean).view(-1, 1, 1)
else:
self.mean = None
if std:
self.std = torch.tensor(std).view(-1, 1, 1)
else:
self.std = None
self.mean = torch.tensor(mean).view(-1, 1, 1) if mean else None
self.std = torch.tensor(std).view(-1, 1, 1) if std else None
def forward(self, predicted_patches, target, mask):
p, c, mpv, bits, device = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits, target.device
bin_size = mpv / (2 ** bits)
# un-normalize input
if self.mean is not None and self.std is not None:
if exists(self.mean) and exists(self.std):
target = target * self.std + self.mean
# reshape target to patches
p = self.patch_size
target = rearrange(target,
"b c (h p1) (w p2) -> b (h w) c (p1 p2) ",
p1=p,
p2=p)
target = target.clamp(max = mpv) # clamp just in case
avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1 = p, p2 = p).contiguous()
avg_target = target.mean(dim=3)
bin_size = self.max_pixel_val / self.output_channel_bits
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size).to(avg_target.device)
channel_bins = torch.arange(bin_size, mpv, bin_size, device = device)
discretized_target = torch.bucketize(avg_target, channel_bins)
discretized_target = F.one_hot(discretized_target,
self.output_channel_bits)
c, bi = self.channels, self.output_channel_bits
discretized_target = rearrange(discretized_target,
"b n c bi -> b n (c bi)",
c=c,
bi=bi)
bin_mask = 2**torch.arange(c * bi - 1, -1,
-1).to(discretized_target.device,
discretized_target.dtype)
target_label = torch.sum(bin_mask * discretized_target, -1)
predicted_patches = predicted_patches[mask]
target_label = target_label[mask]
loss = F.cross_entropy(predicted_patches, target_label)
bin_mask = (2 ** bits) ** torch.arange(0, c, device = device).long()
bin_mask = rearrange(bin_mask, 'c -> () () c')
target_label = torch.sum(bin_mask * discretized_target, dim = -1)
loss = F.cross_entropy(predicted_patches[mask], target_label[mask])
return loss
@@ -87,7 +77,8 @@ class MPPLoss(nn.Module):
class MPP(nn.Module):
def __init__(self,
def __init__(
self,
transformer,
patch_size,
dim,
@@ -98,7 +89,8 @@ class MPP(nn.Module):
replace_prob=0.5,
random_patch_prob=0.5,
mean=None,
std=None):
std=None
):
super().__init__()
self.transformer = transformer
self.loss = MPPLoss(patch_size, channels, output_channel_bits,