mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-01-07 13:32:34 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e3fec2398 | ||
|
|
ce4bcd08fb |
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.2.1',
|
||||
version = '1.2.4',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description_content_type = 'text/markdown',
|
||||
|
||||
@@ -96,6 +96,9 @@ class MPP(nn.Module):
|
||||
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
|
||||
max_pixel_val, mean, std)
|
||||
|
||||
# extract patching function
|
||||
self.patch_to_emb = nn.Sequential(transformer.to_patch_embedding[1:])
|
||||
|
||||
# output transformation
|
||||
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
|
||||
|
||||
@@ -151,7 +154,7 @@ class MPP(nn.Module):
|
||||
masked_input[bool_mask_replace] = self.mask_token
|
||||
|
||||
# linear embedding of patches
|
||||
masked_input = transformer.to_patch_embedding[-1](masked_input)
|
||||
masked_input = self.patch_to_emb(masked_input)
|
||||
|
||||
# add cls token to input sequence
|
||||
b, n, _ = masked_input.shape
|
||||
|
||||
@@ -146,7 +146,7 @@ class ViT(nn.Module):
|
||||
x = self.to_patch_embedding(video)
|
||||
b, f, n, _ = x.shape
|
||||
|
||||
x = x + self.pos_embedding
|
||||
x = x + self.pos_embedding[:, :f, :n]
|
||||
|
||||
if exists(self.spatial_cls_token):
|
||||
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
|
||||
|
||||
Reference in New Issue
Block a user