Compare commits

...

4 Commits

Author       SHA1        Message                                                          Date
Phil Wang    53884f583f  0.19.5                                                           2021-06-16 14:24:46 -07:00
Phil Wang    e616b5dcbc  Merge pull request #101 from zankner/mpp-fix (Mpp fix)           2021-06-16 14:24:26 -07:00
Zack Ankner  a2df363224  adding un-normalizing targets and fix for mask token dimension   2021-04-29 15:43:22 -04:00
Zack Ankner  710b6d57d3  Merge pull request #1 from lucidrains/main (catch up)            2021-04-29 19:33:25 +00:00
2 changed files with 20 additions and 6 deletions

setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.19.4',
+  version = '0.19.5',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

vit_pytorch/mpp.py

@@ -32,14 +32,27 @@ def get_mask_subset_with_prob(patched_input, prob):
 
 class MPPLoss(nn.Module):
     def __init__(self, patch_size, channels, output_channel_bits,
-                 max_pixel_val):
+                 max_pixel_val, mean, std):
         super(MPPLoss, self).__init__()
         self.patch_size = patch_size
         self.channels = channels
         self.output_channel_bits = output_channel_bits
         self.max_pixel_val = max_pixel_val
+
+        if mean:
+            self.mean = torch.tensor(mean).view(-1, 1, 1)
+        else:
+            self.mean = None
+        if std:
+            self.std = torch.tensor(std).view(-1, 1, 1)
+        else:
+            self.std = None
 
     def forward(self, predicted_patches, target, mask):
+        # un-normalize input
+        if self.mean is not None and self.std is not None:
+            target = target * self.std + self.mean
+
         # reshape target to patches
         p = self.patch_size
         target = rearrange(target,
@@ -64,7 +77,6 @@ class MPPLoss(nn.Module):
                                  -1).to(discretized_target.device,
                                         discretized_target.dtype)
         target_label = torch.sum(bin_mask * discretized_target, -1)
-
         predicted_patches = predicted_patches[mask]
         target_label = target_label[mask]
         loss = F.cross_entropy(predicted_patches, target_label)
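A minimal sketch of what the new un-normalization step does in isolation. The shapes and the per-channel statistics below are illustrative assumptions rather than values from this pull request; the point is that a channel-wise mean and std reshaped with .view(-1, 1, 1) broadcast cleanly over a (B, C, H, W) target, which is the same operation MPPLoss.forward now performs.

import torch

# Illustrative per-channel statistics (ImageNet-style values); the pull request
# itself leaves mean and std up to the caller and defaults them to None.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

target = torch.rand(2, 3, 224, 224)           # a normalized (B, C, H, W) batch
mean_t = torch.tensor(mean).view(-1, 1, 1)    # (C, 1, 1), broadcasts over H and W
std_t = torch.tensor(std).view(-1, 1, 1)

unnormalized = target * std_t + mean_t        # mirrors: target = target * self.std + self.mean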
@@ -84,11 +96,13 @@ class MPP(nn.Module):
                  max_pixel_val=1.0,
                  mask_prob=0.15,
                  replace_prob=0.5,
-                 random_patch_prob=0.5):
+                 random_patch_prob=0.5,
+                 mean=None,
+                 std=None):
         super().__init__()
         self.transformer = transformer
         self.loss = MPPLoss(patch_size, channels, output_channel_bits,
-                            max_pixel_val)
+                            max_pixel_val, mean, std)
 
         # output transformation
         self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
@@ -102,7 +116,7 @@ class MPP(nn.Module):
         self.random_patch_prob = random_patch_prob
 
         # token ids
-        self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
+        self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))
 
     def forward(self, input, **kwargs):
         transformer = self.transformer
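A usage sketch of the changed constructor, assuming the wrapper is driven as in the project README (ViT from vit_pytorch, MPP from vit_pytorch.mpp). The model hyperparameters and the normalization statistics are illustrative; only the new mean and std keyword arguments come from this pull request.

import torch
from vit_pytorch import ViT
from vit_pytorch.mpp import MPP

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mpp_trainer = MPP(
    transformer = model,
    patch_size = 32,
    dim = 1024,
    mask_prob = 0.15,
    random_patch_prob = 0.30,
    replace_prob = 0.50,
    mean = [0.485, 0.456, 0.406],  # statistics the inputs were normalized with, so the
    std = [0.229, 0.224, 0.225]    # loss can un-normalize targets before discretizing them
)

images = torch.randn(4, 3, 256, 256)
loss = mpp_trainer(images)         # masked patch prediction loss
loss.backward()

With the second fix, the learned mask token is sized channels * patch_size ** 2, the same width as the flattened pixel patches it substitutes for before the transformer's patch embedding, rather than dim * channels.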