diff --git a/README.md b/README.md
index 627dc0f..604e384 100644
--- a/README.md
+++ b/README.md
@@ -179,6 +179,24 @@ preds = v(images) # (5, 1000) - 5, because 5 images of different resolution abov
 ```
 
+Or if you would rather that the framework auto group the images into variable lengthed sequences that do not exceed a certain max length
+
+```python
+images = [
+    torch.randn(3, 256, 256),
+    torch.randn(3, 128, 128),
+    torch.randn(3, 128, 256),
+    torch.randn(3, 256, 128),
+    torch.randn(3, 64, 256)
+]
+
+preds = v(
+    images,
+    group_images = True,
+    group_max_seq_len = 64
+) # (5, 1000)
+```
+
 ## Distillation
 
diff --git a/setup.py b/setup.py
index cb85938..3db75a6 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.2.8',
+  version = '1.2.9',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/na_vit.py b/vit_pytorch/na_vit.py
index 0497d01..388edcc 100644
--- a/vit_pytorch/na_vit.py
+++ b/vit_pytorch/na_vit.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import List
+from typing import List, Union
 
 import torch
 import torch.nn.functional as F
@@ -17,12 +17,58 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def always(val):
+    return lambda *args: val
+
 def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
 def divisible_by(numer, denom):
     return (numer % denom) == 0
 
+# auto grouping images
+
+def group_images_by_max_seq_len(
+    images: List[Tensor],
+    patch_size: int,
+    calc_token_dropout = None,
+    max_seq_len = 2048
+
+) -> List[List[Tensor]]:
+
+    calc_token_dropout = default(calc_token_dropout, always(0.))
+
+    groups = []
+    group = []
+    seq_len = 0
+
+    if isinstance(calc_token_dropout, (float, int)):
+        calc_token_dropout = always(calc_token_dropout)
+
+    for image in images:
+        assert isinstance(image, Tensor)
+
+        image_dims = image.shape[-2:]
+        ph, pw = map(lambda t: t // patch_size, image_dims)
+
+        image_seq_len = (ph * pw)
+        image_seq_len = int(image_seq_len * (1 - calc_token_dropout(*image_dims)))
+
+        assert image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'
+
+        if (seq_len + image_seq_len) > max_seq_len:
+            groups.append(group)
+            group = []
+            seq_len = 0
+
+        group.append(image)
+        seq_len += image_seq_len
+
+    if len(group) > 0:
+        groups.append(group)
+
+    return groups
+
 # normalization
 
 # they use layernorm without bias, something that pytorch does not offer
@@ -199,13 +245,25 @@ class NaViT(nn.Module):
 
     def forward(
         self,
-        batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
+        batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
+        group_images = False,
+        group_max_seq_len = 2048
     ):
         p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
 
         arange = partial(torch.arange, device = device)
         pad_sequence = partial(orig_pad_sequence, batch_first = True)
 
+        # auto pack if specified
+
+        if group_images:
+            batched_images = group_images_by_max_seq_len(
+                batched_images,
+                patch_size = self.patch_size,
+                calc_token_dropout = self.calc_token_dropout,
+                max_seq_len = group_max_seq_len
+            )
+
         # process images into variable lengthed sequences with attention mask
 
         num_images = []
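
For reference, a minimal sketch (not part of the diff above) of how the new group_images_by_max_seq_len helper packs the README example images when called directly. patch_size = 32 and max_seq_len = 64 are assumptions chosen to mirror the README example, and calc_token_dropout is left at its default, so no tokens are dropped.

```python
# sketch of the new helper, assuming vit-pytorch >= 1.2.9 is installed
import torch
from vit_pytorch.na_vit import group_images_by_max_seq_len

images = [
    torch.randn(3, 256, 256),  # (256 // 32) * (256 // 32) = 64 patches
    torch.randn(3, 128, 128),  # 16 patches
    torch.randn(3, 128, 256),  # 32 patches
    torch.randn(3, 256, 128),  # 32 patches
    torch.randn(3, 64, 256)    # 16 patches
]

groups = group_images_by_max_seq_len(
    images,
    patch_size = 32,
    max_seq_len = 64
)

# packing is greedy and order-preserving: a new group is started whenever adding
# the next image would push the running patch count past max_seq_len
print([len(g) for g in groups])  # [1, 2, 2]
```

Because the grouping is greedy in input order, pre-sorting images by size can yield tighter packing, but any grouping it produces respects the max_seq_len bound that the forward pass expects.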