mirror of https://github.com/lucidrains/vit-pytorch.git
add constant token dropout for NaViT
README.md
@@ -161,7 +161,8 @@ v = NaViT(
     heads = 16,
     mlp_dim = 2048,
     dropout = 0.1,
-    emb_dropout = 0.1
+    emb_dropout = 0.1,
+    token_dropout_prob = 0.1  # token dropout of 10% (keep 90% of tokens)
 )
 
 # 5 images of different resolutions - List[List[Tensor]]
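Note: the README hunk above shows only the changed constructor arguments. Below is a minimal end-to-end sketch of how the new token_dropout_prob is used, assuming the README's import path vit_pytorch.na_vit; the arguments not visible in the hunk (image_size, patch_size, num_classes, dim, depth) and the image shapes are illustrative, chosen so every dimension is divisible by the patch size.

import torch
from vit_pytorch.na_vit import NaViT

v = NaViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    token_dropout_prob = 0.1  # token dropout of 10% (keep 90% of tokens)
)

# 5 images of different resolutions, already grouped - List[List[Tensor]]
images = [
    [torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
    [torch.randn(3, 128, 256), torch.randn(3, 256, 128)],
    [torch.randn(3, 64, 256)]
]

preds = v(images)  # (5, 1000) - one prediction per image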
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.2.6',
+  version = '1.2.7',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
vit_pytorch/na_vit.py
@@ -138,10 +138,15 @@ class Transformer(nn.Module):
         return self.norm(x)
 
 class NaViT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = 0.):
         super().__init__()
         image_height, image_width = pair(image_size)
 
+        # what percent of tokens to drop out
+        # in the paper, they found this should vary depending on resolution (todo - figure out how to do this, maybe with a callback?)
+
+        self.token_dropout_prob = token_dropout_prob
+
         assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
 
         patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
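The todo in the new comment points at resolution-dependent drop rates. One hypothetical shape for that, not part of this commit: accept either a constant or a callable mapping an image's (height, width) to a drop probability. The helper name and the scaling rule below are assumptions for illustration only.

# hypothetical sketch - normalize token_dropout_prob into a callable
def to_token_dropout_fn(token_dropout_prob):
    if callable(token_dropout_prob):
        return token_dropout_prob
    # constant probability, ignoring resolution
    return lambda height, width: token_dropout_prob

# e.g. drop more tokens from higher-resolution images, capped at 90%
calc_token_dropout = to_token_dropout_fn(
    lambda height, width: min(0.9, 0.1 * (height * width) / (256 * 256))
)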
@@ -185,7 +190,7 @@ class NaViT(nn.Module):
         self,
         batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
     ):
-        p, c, device = self.patch_size, self.channels, self.device
+        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, self.token_dropout_prob > 0.
 
         arange = partial(torch.arange, device = device)
         pad_sequence = partial(orig_pad_sequence, batch_first = True)
@@ -219,6 +224,14 @@ class NaViT(nn.Module):
             pos = rearrange(pos, 'h w c -> (h w) c')
             seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
 
+            seq_len = seq.shape[-2]
+
+            if has_token_dropout:
+                num_keep = max(1, int(seq_len * (1 - self.token_dropout_prob)))
+                keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
+                seq = seq[keep_indices]
+                pos = pos[keep_indices]
+
             image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
             sequences.append(seq)
             positions.append(pos)
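The keep_indices line above uses a common PyTorch idiom: taking the top-k indices of i.i.d. Gaussian noise selects a uniformly random subset of num_keep positions without replacement, equivalent to torch.randperm(seq_len)[:num_keep]. A standalone sketch of the same selection, with illustrative sizes:

import torch

seq_len, dim = 64, 16                        # illustrative sizes
seq = torch.randn(seq_len, dim)              # stand-in for one image's patch tokens
pos = torch.arange(seq_len)                  # stand-in for its position ids

token_dropout_prob = 0.1
num_keep = max(1, int(seq_len * (1 - token_dropout_prob)))  # keep at least one token

# top-k over random noise = uniform random subset, no replacement
keep_indices = torch.randn((seq_len,)).topk(num_keep, dim = -1).indices

seq, pos = seq[keep_indices], pos[keep_indices]
print(seq.shape, pos.shape)                  # torch.Size([57, 16]) torch.Size([57])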