add constant token dropout for NaViT

This commit is contained in:
Phil Wang
2023-07-24 14:14:36 -07:00
parent 598cffab53
commit 17675e0de4
3 changed files with 18 additions and 4 deletions

View File

@@ -161,7 +161,8 @@ v = NaViT(
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
emb_dropout = 0.1,
token_dropout_prob = 0.1 # token dropout of 10% (keep 90% of tokens)
)
# 5 images of different resolutions - List[List[Tensor]]

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.2.6',
version = '1.2.7',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',

View File

@@ -138,10 +138,15 @@ class Transformer(nn.Module):
return self.norm(x)
class NaViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = 0.):
super().__init__()
image_height, image_width = pair(image_size)
# what percent of tokens to dropout
# in paper, they found this should vary depending on resolution (todo - figure out how to do this, maybe with callback?)
self.token_dropout_prob = token_dropout_prob
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
@@ -185,7 +190,7 @@ class NaViT(nn.Module):
self,
batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
):
p, c, device = self.patch_size, self.channels, self.device
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, self.token_dropout_prob > 0.
arange = partial(torch.arange, device = device)
pad_sequence = partial(orig_pad_sequence, batch_first = True)
@@ -219,6 +224,14 @@ class NaViT(nn.Module):
pos = rearrange(pos, 'h w c -> (h w) c')
seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
seq_len = seq.shape[-2]
if has_token_dropout:
num_keep = max(1, int(seq_len * (1 - self.token_dropout_prob)))
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
seq = seq[keep_indices]
pos = pos[keep_indices]
image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
sequences.append(seq)
positions.append(pos)