wrap up NaViT

This commit is contained in:
Phil Wang
2023-07-25 10:38:55 -07:00
parent 32974c33df
commit 6e2393de95
3 changed files with 79 additions and 3 deletions

View File

@@ -179,6 +179,24 @@ preds = v(images) # (5, 1000) - 5, because 5 images of different resolution abov
```
Or if you would rather that the framework auto group the images into variable lengthed sequences that do not exceed a certain max length
```python
images = [
torch.randn(3, 256, 256),
torch.randn(3, 128, 128),
torch.randn(3, 128, 256),
torch.randn(3, 256, 128),
torch.randn(3, 64, 256)
]
preds = v(
images,
group_images = True,
group_max_seq_len = 64
) # (5, 1000)
```
## Distillation
<img src="./images/distill.png" width="300px"></img>

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.2.8',
version = '1.2.9',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',

View File

@@ -1,5 +1,5 @@
from functools import partial
from typing import List
from typing import List, Union
import torch
import torch.nn.functional as F
@@ -17,12 +17,58 @@ def exists(val):
def default(val, d):
return val if exists(val) else d
def always(val):
return lambda *args: val
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def divisible_by(numer, denom):
return (numer % denom) == 0
# auto grouping images
def group_images_by_max_seq_len(
images: List[Tensor],
patch_size: int,
calc_token_dropout = None,
max_seq_len = 2048
) -> List[List[Tensor]]:
calc_token_dropout = default(calc_token_dropout, always(0.))
groups = []
group = []
seq_len = 0
if isinstance(calc_token_dropout, (float, int)):
calc_token_dropout = always(calc_token_dropout)
for image in images:
assert isinstance(image, Tensor)
image_dims = image.shape[-2:]
ph, pw = map(lambda t: t // patch_size, image_dims)
image_seq_len = (ph * pw)
image_seq_len = int(image_seq_len * (1 - calc_token_dropout(*image_dims)))
assert image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'
if (seq_len + image_seq_len) > max_seq_len:
groups.append(group)
group = []
seq_len = 0
group.append(image)
seq_len += image_seq_len
if len(group) > 0:
groups.append(group)
return groups
# normalization
# they use layernorm without bias, something that pytorch does not offer
@@ -199,13 +245,25 @@ class NaViT(nn.Module):
def forward(
self,
batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
group_images = False,
group_max_seq_len = 2048
):
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
arange = partial(torch.arange, device = device)
pad_sequence = partial(orig_pad_sequence, batch_first = True)
# auto pack if specified
if group_images:
batched_images = group_images_by_max_seq_len(
batched_images,
patch_size = self.patch_size,
calc_token_dropout = self.calc_token_dropout,
max_seq_len = group_max_seq_len
)
# process images into variable lengthed sequences with attention mask
num_images = []