mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6e2393de95 | ||
|
|
32974c33df | ||
|
|
17675e0de4 | ||
|
|
598cffab53 |
25
README.md
25
README.md
@@ -7,7 +7,7 @@
|
||||
- [Usage](#usage)
|
||||
- [Parameters](#parameters)
|
||||
- [Simple ViT](#simple-vit)
|
||||
- [NaViT](#na-vit)
|
||||
- [NaViT](#navit)
|
||||
- [Distillation](#distillation)
|
||||
- [Deep ViT](#deep-vit)
|
||||
- [CaiT](#cait)
|
||||
@@ -142,7 +142,7 @@ preds = v(img) # (1, 1000)
|
||||
|
||||
## NaViT
|
||||
|
||||
<img src="./images/na_vit.png" width="450px"></img>
|
||||
<img src="./images/navit.png" width="450px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2307.06304">This paper</a> proposes to leverage the flexibility of attention and masking for variable lengthed sequences to train images of multiple resolution, packed into a single batch. They demonstrate much faster training and improved accuracies, with the only cost being extra complexity in the architecture and dataloading. They use factorized 2d positional encodings, token dropping, as well as query-key normalization.
|
||||
|
||||
@@ -161,7 +161,8 @@ v = NaViT(
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
emb_dropout = 0.1,
|
||||
token_dropout_prob = 0.1 # token dropout of 10% (keep 90% of tokens)
|
||||
)
|
||||
|
||||
# 5 images of different resolutions - List[List[Tensor]]
|
||||
@@ -178,6 +179,24 @@ preds = v(images) # (5, 1000) - 5, because 5 images of different resolution abov
|
||||
|
||||
```
|
||||
|
||||
Or if you would rather that the framework auto group the images into variable lengthed sequences that do not exceed a certain max length
|
||||
|
||||
```python
|
||||
images = [
|
||||
torch.randn(3, 256, 256),
|
||||
torch.randn(3, 128, 128),
|
||||
torch.randn(3, 128, 256),
|
||||
torch.randn(3, 256, 128),
|
||||
torch.randn(3, 64, 256)
|
||||
]
|
||||
|
||||
preds = v(
|
||||
images,
|
||||
group_images = True,
|
||||
group_max_seq_len = 64
|
||||
) # (5, 1000)
|
||||
```
|
||||
|
||||
## Distillation
|
||||
|
||||
<img src="./images/distill.png" width="300px"></img>
|
||||
|
||||
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.2.5',
|
||||
version = '1.2.9',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description_content_type = 'text/markdown',
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from functools import partial
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -17,12 +17,58 @@ def exists(val):
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def always(val):
|
||||
return lambda *args: val
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
# auto grouping images
|
||||
|
||||
def group_images_by_max_seq_len(
|
||||
images: List[Tensor],
|
||||
patch_size: int,
|
||||
calc_token_dropout = None,
|
||||
max_seq_len = 2048
|
||||
|
||||
) -> List[List[Tensor]]:
|
||||
|
||||
calc_token_dropout = default(calc_token_dropout, always(0.))
|
||||
|
||||
groups = []
|
||||
group = []
|
||||
seq_len = 0
|
||||
|
||||
if isinstance(calc_token_dropout, (float, int)):
|
||||
calc_token_dropout = always(calc_token_dropout)
|
||||
|
||||
for image in images:
|
||||
assert isinstance(image, Tensor)
|
||||
|
||||
image_dims = image.shape[-2:]
|
||||
ph, pw = map(lambda t: t // patch_size, image_dims)
|
||||
|
||||
image_seq_len = (ph * pw)
|
||||
image_seq_len = int(image_seq_len * (1 - calc_token_dropout(*image_dims)))
|
||||
|
||||
assert image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'
|
||||
|
||||
if (seq_len + image_seq_len) > max_seq_len:
|
||||
groups.append(group)
|
||||
group = []
|
||||
seq_len = 0
|
||||
|
||||
group.append(image)
|
||||
seq_len += image_seq_len
|
||||
|
||||
if len(group) > 0:
|
||||
groups.append(group)
|
||||
|
||||
return groups
|
||||
|
||||
# normalization
|
||||
# they use layernorm without bias, something that pytorch does not offer
|
||||
|
||||
@@ -138,10 +184,26 @@ class Transformer(nn.Module):
|
||||
return self.norm(x)
|
||||
|
||||
class NaViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = None):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
# what percent of tokens to dropout
|
||||
# if int or float given, then assume constant dropout prob
|
||||
# otherwise accept a callback that in turn calculates dropout prob from height and width
|
||||
|
||||
self.calc_token_dropout = None
|
||||
|
||||
if callable(token_dropout_prob):
|
||||
self.calc_token_dropout = token_dropout_prob
|
||||
|
||||
elif isinstance(token_dropout_prob, (float, int)):
|
||||
assert 0. < token_dropout_prob < 1.
|
||||
token_dropout_prob = float(token_dropout_prob)
|
||||
self.calc_token_dropout = lambda height, width: token_dropout_prob
|
||||
|
||||
# calculate patching related stuff
|
||||
|
||||
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
|
||||
@@ -183,13 +245,25 @@ class NaViT(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
|
||||
batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
|
||||
group_images = False,
|
||||
group_max_seq_len = 2048
|
||||
):
|
||||
p, c, device = self.patch_size, self.channels, self.device
|
||||
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
|
||||
|
||||
arange = partial(torch.arange, device = device)
|
||||
pad_sequence = partial(orig_pad_sequence, batch_first = True)
|
||||
|
||||
# auto pack if specified
|
||||
|
||||
if group_images:
|
||||
batched_images = group_images_by_max_seq_len(
|
||||
batched_images,
|
||||
patch_size = self.patch_size,
|
||||
calc_token_dropout = self.calc_token_dropout,
|
||||
max_seq_len = group_max_seq_len
|
||||
)
|
||||
|
||||
# process images into variable lengthed sequences with attention mask
|
||||
|
||||
num_images = []
|
||||
@@ -219,6 +293,16 @@ class NaViT(nn.Module):
|
||||
pos = rearrange(pos, 'h w c -> (h w) c')
|
||||
seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
|
||||
|
||||
seq_len = seq.shape[-2]
|
||||
|
||||
if has_token_dropout:
|
||||
token_dropout = self.calc_token_dropout(*image_dims)
|
||||
num_keep = max(1, int(seq_len * (1 - token_dropout)))
|
||||
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
|
||||
|
||||
seq = seq[keep_indices]
|
||||
pos = pos[keep_indices]
|
||||
|
||||
image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
|
||||
sequences.append(seq)
|
||||
positions.append(pos)
|
||||
|
||||
Reference in New Issue
Block a user