mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
wrap up NaViT
This commit is contained in:
18
README.md
18
README.md
@@ -179,6 +179,24 @@ preds = v(images) # (5, 1000) - 5, because 5 images of different resolution abov
|
||||
|
||||
```
|
||||
|
||||
Or if you would rather that the framework auto group the images into variable lengthed sequences that do not exceed a certain max length
|
||||
|
||||
```python
|
||||
images = [
|
||||
torch.randn(3, 256, 256),
|
||||
torch.randn(3, 128, 128),
|
||||
torch.randn(3, 128, 256),
|
||||
torch.randn(3, 256, 128),
|
||||
torch.randn(3, 64, 256)
|
||||
]
|
||||
|
||||
preds = v(
|
||||
images,
|
||||
group_images = True,
|
||||
group_max_seq_len = 64
|
||||
) # (5, 1000)
|
||||
```
|
||||
|
||||
## Distillation
|
||||
|
||||
<img src="./images/distill.png" width="300px"></img>
|
||||
|
||||
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.2.8',
|
||||
version = '1.2.9',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description_content_type = 'text/markdown',
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from functools import partial
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -17,12 +17,58 @@ def exists(val):
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def always(val):
|
||||
return lambda *args: val
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
# auto grouping images
|
||||
|
||||
def group_images_by_max_seq_len(
|
||||
images: List[Tensor],
|
||||
patch_size: int,
|
||||
calc_token_dropout = None,
|
||||
max_seq_len = 2048
|
||||
|
||||
) -> List[List[Tensor]]:
|
||||
|
||||
calc_token_dropout = default(calc_token_dropout, always(0.))
|
||||
|
||||
groups = []
|
||||
group = []
|
||||
seq_len = 0
|
||||
|
||||
if isinstance(calc_token_dropout, (float, int)):
|
||||
calc_token_dropout = always(calc_token_dropout)
|
||||
|
||||
for image in images:
|
||||
assert isinstance(image, Tensor)
|
||||
|
||||
image_dims = image.shape[-2:]
|
||||
ph, pw = map(lambda t: t // patch_size, image_dims)
|
||||
|
||||
image_seq_len = (ph * pw)
|
||||
image_seq_len = int(image_seq_len * (1 - calc_token_dropout(*image_dims)))
|
||||
|
||||
assert image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'
|
||||
|
||||
if (seq_len + image_seq_len) > max_seq_len:
|
||||
groups.append(group)
|
||||
group = []
|
||||
seq_len = 0
|
||||
|
||||
group.append(image)
|
||||
seq_len += image_seq_len
|
||||
|
||||
if len(group) > 0:
|
||||
groups.append(group)
|
||||
|
||||
return groups
|
||||
|
||||
# normalization
|
||||
# they use layernorm without bias, something that pytorch does not offer
|
||||
|
||||
@@ -199,13 +245,25 @@ class NaViT(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
|
||||
batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
|
||||
group_images = False,
|
||||
group_max_seq_len = 2048
|
||||
):
|
||||
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
|
||||
|
||||
arange = partial(torch.arange, device = device)
|
||||
pad_sequence = partial(orig_pad_sequence, batch_first = True)
|
||||
|
||||
# auto pack if specified
|
||||
|
||||
if group_images:
|
||||
batched_images = group_images_by_max_seq_len(
|
||||
batched_images,
|
||||
patch_size = self.patch_size,
|
||||
calc_token_dropout = self.calc_token_dropout,
|
||||
max_seq_len = group_max_seq_len
|
||||
)
|
||||
|
||||
# process images into variable lengthed sequences with attention mask
|
||||
|
||||
num_images = []
|
||||
|
||||
Reference in New Issue
Block a user