From cd210905d943e56399ff0351fe650b63ce987d5d Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 24 Jul 2023 14:30:30 -0700
Subject: [PATCH] one can pass a callback to token_dropout_prob for NaViT that takes in height and width and calculates an appropriate dropout rate

---
 setup.py              |  2 +-
 vit_pytorch/na_vit.py | 23 ++++++++++++++++++-----
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/setup.py b/setup.py
index da67652..cb85938 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.2.7',
+  version = '1.2.8',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/na_vit.py b/vit_pytorch/na_vit.py
index 46e88c0..00ac8be 100644
--- a/vit_pytorch/na_vit.py
+++ b/vit_pytorch/na_vit.py
@@ -138,14 +138,25 @@ class Transformer(nn.Module):
         return self.norm(x)

 class NaViT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = 0.):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = None):
         super().__init__()
         image_height, image_width = pair(image_size)

         # what percent of tokens to dropout
-        # in paper, they found this should vary depending on resolution (todo - figure out how to do this, maybe with callback?)
+        # if int or float given, then assume constant dropout prob
+        # otherwise accept a callback that in turn calculates dropout prob from height and width

-        self.token_dropout_prob = token_dropout_prob
+        self.calc_token_dropout = None
+
+        if callable(token_dropout_prob):
+            self.calc_token_dropout = token_dropout_prob
+
+        elif isinstance(token_dropout_prob, (float, int)):
+            assert 0. < token_dropout_prob < 1.
+            token_dropout_prob = float(token_dropout_prob)
+            self.calc_token_dropout = lambda height, width: token_dropout_prob
+
+        # calculate patching related stuff

         assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'

@@ -190,7 +201,7 @@ class NaViT(nn.Module):
         self,
         batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
     ):
-        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, self.token_dropout_prob > 0.
+        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)

         arange = partial(torch.arange, device = device)
         pad_sequence = partial(orig_pad_sequence, batch_first = True)
@@ -227,8 +238,10 @@ class NaViT(nn.Module):
                 seq_len = seq.shape[-2]

                 if has_token_dropout:
-                    num_keep = max(1, int(seq_len * (1 - self.token_dropout_prob)))
+                    token_dropout = self.calc_token_dropout(*image_dims)
+                    num_keep = max(1, int(seq_len * (1 - token_dropout)))
                     keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
+
                     seq = seq[keep_indices]
                     pos = pos[keep_indices]
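
Usage sketch for the new option. The constructor keywords mirror the __init__ signature in this patch; the specific hyperparameter values and the resolution-dependent schedule inside the lambda are illustrative assumptions, not something prescribed by the patch or the NaViT paper:

    import torch
    from vit_pytorch.na_vit import NaViT

    # token_dropout_prob now accepts either a constant float in (0, 1)
    # or a callable (height, width) -> dropout prob, resolved per image

    v = NaViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1,
        # hypothetical schedule: drop more tokens from higher resolution
        # images, capped so some tokens are always kept
        token_dropout_prob = lambda height, width: min(0.9, 0.25 * (height * width) / (256 * 256))
    )

    # images are grouped by the caller; resolutions may differ, but height
    # and width must each be divisible by patch_size (32 here)

    images = [
        [torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
        [torch.randn(3, 64, 256)]
    ]

    preds = v(images)  # (3, 1000) - one prediction per image

Note that the 0 < p < 1 assertion only guards the constant form; a callback's return value is used as-is in num_keep = max(1, int(seq_len * (1 - token_dropout))), with the max(1, ...) guaranteeing at least one token survives per image.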