Compare commits

...

4 Commits

Author SHA1 Message Date
Phil Wang
b483b16833 0.18.4 2021-05-18 14:40:33 -07:00
Phil Wang
c457573808 Merge pull request #118 from loctruong96/main
update  mpp.py to work on GPU
2021-05-18 14:40:17 -07:00
Loc Truong
e75b6d0251 Update mpp.py
fix issue with GPU device mismatch
2021-05-16 20:07:49 -07:00
Phil Wang
679e5be3e7 apply scale to 2d rel pos bias in levit 2021-05-10 11:37:23 -07:00
3 changed files with 7 additions and 7 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.18.1',
version = '0.18.4',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -84,7 +84,7 @@ class Attention(nn.Module):
def apply_pos_bias(self, fmap):
bias = self.pos_bias(self.pos_indices)
bias = rearrange(bias, 'i j h -> () h i j')
return fmap + bias
return fmap + (bias / self.scale)
def forward(self, x):
b, n, *_, h = *x.shape, self.heads

View File

@@ -50,7 +50,7 @@ class MPPLoss(nn.Module):
avg_target = target.mean(dim=3)
bin_size = self.max_pixel_val / self.output_channel_bits
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size)
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size).to(avg_target.device)
discretized_target = torch.bucketize(avg_target, channel_bins)
discretized_target = F.one_hot(discretized_target,
self.output_channel_bits)
@@ -86,7 +86,6 @@ class MPP(nn.Module):
replace_prob=0.5,
random_patch_prob=0.5):
super().__init__()
self.transformer = transformer
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
max_pixel_val)
@@ -127,8 +126,9 @@ class MPP(nn.Module):
random_patch_sampling_prob = self.random_patch_prob / (
1 - self.replace_prob)
random_patch_prob = prob_mask_like(input,
random_patch_sampling_prob)
bool_random_patch_prob = mask * random_patch_prob == True
random_patch_sampling_prob).to(mask.device)
bool_random_patch_prob = mask * (random_patch_prob == True)
random_patches = torch.randint(0,
input.shape[1],
(input.shape[0], input.shape[1]),
@@ -140,7 +140,7 @@ class MPP(nn.Module):
bool_random_patch_prob]
# [mask] input
replace_prob = prob_mask_like(input, self.replace_prob)
replace_prob = prob_mask_like(input, self.replace_prob).to(mask.device)
bool_mask_replace = (mask * replace_prob) == True
masked_input[bool_mask_replace] = self.mask_token