mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 16:12:29 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53884f583f | ||
|
|
e616b5dcbc | ||
|
|
a2df363224 | ||
|
|
710b6d57d3 |
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.19.4',
|
||||
version = '0.19.5',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -32,14 +32,27 @@ def get_mask_subset_with_prob(patched_input, prob):
|
||||
|
||||
class MPPLoss(nn.Module):
|
||||
def __init__(self, patch_size, channels, output_channel_bits,
|
||||
max_pixel_val):
|
||||
max_pixel_val, mean, std):
|
||||
super(MPPLoss, self).__init__()
|
||||
self.patch_size = patch_size
|
||||
self.channels = channels
|
||||
self.output_channel_bits = output_channel_bits
|
||||
self.max_pixel_val = max_pixel_val
|
||||
|
||||
if mean:
|
||||
self.mean = torch.tensor(mean).view(-1, 1, 1)
|
||||
else:
|
||||
self.mean = None
|
||||
if std:
|
||||
self.std = torch.tensor(std).view(-1, 1, 1)
|
||||
else:
|
||||
self.std = None
|
||||
|
||||
def forward(self, predicted_patches, target, mask):
|
||||
# un-normalize input
|
||||
if self.mean is not None and self.std is not None:
|
||||
target = target * self.std + self.mean
|
||||
|
||||
# reshape target to patches
|
||||
p = self.patch_size
|
||||
target = rearrange(target,
|
||||
@@ -64,7 +77,6 @@ class MPPLoss(nn.Module):
|
||||
-1).to(discretized_target.device,
|
||||
discretized_target.dtype)
|
||||
target_label = torch.sum(bin_mask * discretized_target, -1)
|
||||
|
||||
predicted_patches = predicted_patches[mask]
|
||||
target_label = target_label[mask]
|
||||
loss = F.cross_entropy(predicted_patches, target_label)
|
||||
@@ -84,11 +96,13 @@ class MPP(nn.Module):
|
||||
max_pixel_val=1.0,
|
||||
mask_prob=0.15,
|
||||
replace_prob=0.5,
|
||||
random_patch_prob=0.5):
|
||||
random_patch_prob=0.5,
|
||||
mean=None,
|
||||
std=None):
|
||||
super().__init__()
|
||||
self.transformer = transformer
|
||||
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
|
||||
max_pixel_val)
|
||||
max_pixel_val, mean, std)
|
||||
|
||||
# output transformation
|
||||
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
|
||||
@@ -102,7 +116,7 @@ class MPP(nn.Module):
|
||||
self.random_patch_prob = random_patch_prob
|
||||
|
||||
# token ids
|
||||
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
|
||||
self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))
|
||||
|
||||
def forward(self, input, **kwargs):
|
||||
transformer = self.transformer
|
||||
|
||||
Reference in New Issue
Block a user