Phil Wang
2023-06-28 08:02:43 -07:00
parent ce4bcd08fb
commit 9e3fec2398
2 changed files with 5 additions and 2 deletions

setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.2.2',
+  version = '1.2.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',

vit_pytorch/mpp.py

@@ -96,6 +96,9 @@ class MPP(nn.Module):
         self.loss = MPPLoss(patch_size, channels, output_channel_bits,
                             max_pixel_val, mean, std)

+        # extract patching function
+        self.patch_to_emb = nn.Sequential(transformer.to_patch_embedding[1:])
+
         # output transformation
         self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
@@ -151,7 +154,7 @@ class MPP(nn.Module):
         masked_input[bool_mask_replace] = self.mask_token

         # linear embedding of patches
-        masked_input = transformer.to_patch_embedding[-1](masked_input)
+        masked_input = self.patch_to_emb(masked_input)

         # add cls token to input sequence
         b, n, _ = masked_input.shape
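For context, a minimal sketch of why the slice changed. MPP masks patches that have already been rearranged into a (batch, num_patches, patch_dim) tensor, so embedding them should run every layer of the ViT patch embedding except the leading Rearrange. Indexing [-1] ran only the final module, while [1:] runs everything after the Rearrange. The Sequential layout and dimensions below are assumptions mirroring this repo's ViT definition, not part of this commit:

import torch
from torch import nn
from einops.layers.torch import Rearrange

patch_dim, dim = 3 * 16 * 16, 1024

# assumed layout of ViT.to_patch_embedding: Rearrange first, then norm / project / norm
to_patch_embedding = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = 16, p2 = 16),
    nn.LayerNorm(patch_dim),
    nn.Linear(patch_dim, dim),
    nn.LayerNorm(dim),
)

# patches as MPP sees them after masking: already rearranged, still patch_dim-sized
patches = torch.randn(1, 64, patch_dim)

# before: to_patch_embedding[-1] applied only the final LayerNorm(dim),
# skipping the Linear projection entirely
# after: every module past the Rearrange runs, norming and projecting the patches
patch_to_emb = nn.Sequential(to_patch_embedding[1:])

assert patch_to_emb(patches).shape == (1, 64, dim)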