mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-05-12 10:58:23 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | b483b16833 | | |
| | c457573808 | | |
| | e75b6d0251 | | |
| | 679e5be3e7 | | |
2
setup.py
2
setup.py
```diff
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
 name = 'vit-pytorch',
 packages = find_packages(exclude=['examples']),
-version = '0.18.1',
+version = '0.18.4',
 license='MIT',
 description = 'Vision Transformer (ViT) - Pytorch',
 author = 'Phil Wang',
```
```diff
@@ -84,7 +84,7 @@ class Attention(nn.Module):
 def apply_pos_bias(self, fmap):
 bias = self.pos_bias(self.pos_indices)
 bias = rearrange(bias, 'i j h -> () h i j')
-return fmap + bias
+return fmap + (bias / self.scale)
 
 def forward(self, x):
 b, n, *_, h = *x.shape, self.heads
```
```diff
@@ -50,7 +50,7 @@ class MPPLoss(nn.Module):
 avg_target = target.mean(dim=3)
 
 bin_size = self.max_pixel_val / self.output_channel_bits
-channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size)
+channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size).to(avg_target.device)
 discretized_target = torch.bucketize(avg_target, channel_bins)
 discretized_target = F.one_hot(discretized_target,
 self.output_channel_bits)
```
```diff
@@ -86,7 +86,6 @@ class MPP(nn.Module):
 replace_prob=0.5,
 random_patch_prob=0.5):
 super().__init__()
-
 self.transformer = transformer
 self.loss = MPPLoss(patch_size, channels, output_channel_bits,
 max_pixel_val)
```
```diff
@@ -127,8 +126,9 @@ class MPP(nn.Module):
 random_patch_sampling_prob = self.random_patch_prob / (
 1 - self.replace_prob)
 random_patch_prob = prob_mask_like(input,
-random_patch_sampling_prob)
-bool_random_patch_prob = mask * random_patch_prob == True
+random_patch_sampling_prob).to(mask.device)
+
+bool_random_patch_prob = mask * (random_patch_prob == True)
 random_patches = torch.randint(0,
 input.shape[1],
 (input.shape[0], input.shape[1]),
```
```diff
@@ -140,7 +140,7 @@ class MPP(nn.Module):
 bool_random_patch_prob]
 
 # [mask] input
-replace_prob = prob_mask_like(input, self.replace_prob)
+replace_prob = prob_mask_like(input, self.replace_prob).to(mask.device)
 bool_mask_replace = (mask * replace_prob) == True
 masked_input[bool_mask_replace] = self.mask_token
 
```
Reference in New Issue
Block a user