Compare commits

...

3 Commits
1.6.5 ... 1.6.7

Author SHA1 Message Date
Phil Wang
96f66d2754 address https://github.com/lucidrains/vit-pytorch/issues/306 2024-04-18 09:44:29 -07:00
Phil Wang
12249dcc5f address https://github.com/lucidrains/vit-pytorch/issues/304 2024-04-17 09:40:03 -07:00
SOUMYADIP MAL
8b8da8dede Update setup.py (#303) 2024-04-17 08:21:30 -07:00
3 changed files with 17 additions and 11 deletions

View File

@@ -1,11 +1,15 @@
from setuptools import setup, find_packages
with open('README.md') as f:
long_description = f.read()
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.6.5',
version = '1.6.7',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
long_description_content_type = 'text/markdown',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',

View File

@@ -170,12 +170,13 @@ class ImageEmbedder(nn.Module):
dim,
image_size,
patch_size,
dropout = 0.
dropout = 0.,
channels = 3
):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = 3 * patch_size ** 2
patch_dim = channels * patch_size ** 2
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
@@ -223,11 +224,12 @@ class CrossViT(nn.Module):
cross_attn_dim_head = 64,
depth = 3,
dropout = 0.1,
emb_dropout = 0.1
emb_dropout = 0.1,
channels = 3
):
super().__init__()
self.sm_image_embedder = ImageEmbedder(dim = sm_dim, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
self.lg_image_embedder = ImageEmbedder(dim = lg_dim, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)
self.sm_image_embedder = ImageEmbedder(dim = sm_dim, channels= channels, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
self.lg_image_embedder = ImageEmbedder(dim = lg_dim, channels = channels, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)
self.multi_scale_encoder = MultiScaleEncoder(
depth = depth,

View File

@@ -198,7 +198,7 @@ class NaViT(nn.Module):
self.calc_token_dropout = token_dropout_prob
elif isinstance(token_dropout_prob, (float, int)):
assert 0. < token_dropout_prob < 1.
assert 0. <= token_dropout_prob < 1.
token_dropout_prob = float(token_dropout_prob)
self.calc_token_dropout = lambda height, width: token_dropout_prob
@@ -249,7 +249,7 @@ class NaViT(nn.Module):
group_images = False,
group_max_seq_len = 2048
):
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout) and self.training
arange = partial(torch.arange, device = device)
pad_sequence = partial(orig_pad_sequence, batch_first = True)
@@ -260,7 +260,7 @@ class NaViT(nn.Module):
batched_images = group_images_by_max_seq_len(
batched_images,
patch_size = self.patch_size,
calc_token_dropout = self.calc_token_dropout,
calc_token_dropout = self.calc_token_dropout if self.training else None,
max_seq_len = group_max_seq_len
)
@@ -314,8 +314,8 @@ class NaViT(nn.Module):
# derive key padding mask
lengths = torch.tensor([seq.shape[-2] for seq in batched_sequences], device = device, dtype = torch.long)
max_length = arange(lengths.amax().item())
key_pad_mask = rearrange(lengths, 'b -> b 1') <= rearrange(max_length, 'n -> 1 n')
seq_arange = arange(lengths.amax().item())
key_pad_mask = rearrange(seq_arange, 'n -> 1 n') < rearrange(lengths, 'b -> b 1')
# derive attention mask, and combine with key padding mask from above