mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
fix mpp
This commit is contained in:
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
|||||||
setup(
|
setup(
|
||||||
name = 'vit-pytorch',
|
name = 'vit-pytorch',
|
||||||
packages = find_packages(exclude=['examples']),
|
packages = find_packages(exclude=['examples']),
|
||||||
version = '1.2.2',
|
version = '1.2.4',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'Vision Transformer (ViT) - Pytorch',
|
description = 'Vision Transformer (ViT) - Pytorch',
|
||||||
long_description_content_type = 'text/markdown',
|
long_description_content_type = 'text/markdown',
|
||||||
|
|||||||
@@ -96,6 +96,9 @@ class MPP(nn.Module):
|
|||||||
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
|
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
|
||||||
max_pixel_val, mean, std)
|
max_pixel_val, mean, std)
|
||||||
|
|
||||||
|
# extract patching function
|
||||||
|
self.patch_to_emb = nn.Sequential(transformer.to_patch_embedding[1:])
|
||||||
|
|
||||||
# output transformation
|
# output transformation
|
||||||
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
|
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
|
||||||
|
|
||||||
@@ -151,7 +154,7 @@ class MPP(nn.Module):
|
|||||||
masked_input[bool_mask_replace] = self.mask_token
|
masked_input[bool_mask_replace] = self.mask_token
|
||||||
|
|
||||||
# linear embedding of patches
|
# linear embedding of patches
|
||||||
masked_input = transformer.to_patch_embedding[-1](masked_input)
|
masked_input = self.patch_to_emb(masked_input)
|
||||||
|
|
||||||
# add cls token to input sequence
|
# add cls token to input sequence
|
||||||
b, n, _ = masked_input.shape
|
b, n, _ = masked_input.shape
|
||||||
|
|||||||
Reference in New Issue
Block a user