Compare commits


6 Commits

Author SHA1 Message Date
lucidrains
5888f05300 1.16.4 2025-12-07 04:32:52 -08:00
Amit Moryossef
d518e89573 cache position grids in NaViT forward pass (#354)
Use lru_cache to cache unique (ph, pw, device) position grids, avoiding
redundant computation when multiple images share the same patch
dimensions. Cache persists across forward passes.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-07 04:32:30 -08:00
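
A minimal sketch of the caching behavior this commit describes, using the posemb_grid helper as it appears in the diff below; the example inputs are made up:

    import torch
    from functools import lru_cache

    @lru_cache(maxsize=128)
    def posemb_grid(ph, pw, device):
        # flattened (row, col) indices for a ph x pw grid of patches
        h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
        w_idx = torch.arange(pw, device=device).repeat(ph)
        return torch.stack([h_idx, w_idx], dim=-1)

    device = torch.device('cpu')
    a = posemb_grid(2, 3, device)
    b = posemb_grid(2, 3, device)   # cache hit: same tensor object, no recomputation
    assert a is b
    assert a.tolist() == [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]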
lucidrains
dd6462d19b release small navit perf 2025-12-06 04:57:12 -08:00
Amit Moryossef
a1ee1daa1a optimize NaViT with SDPA and vectorized forward pass (#353)
- Replace manual attention with F.scaled_dot_product_attention
- Use repeat_interleave instead of meshgrid for position computation
- Build image_ids efficiently with repeat_interleave instead of F.pad
- Remove unused Rearrange import

~56% speedup (91ms -> 58ms on 512 variable-sized images)
Numerically equivalent (max diff ~5e-4, within flash attention tolerance)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-06 04:56:40 -08:00
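
For context, a toy check (random tensors, arbitrary shapes) that F.scaled_dot_product_attention reproduces plain softmax attention; it ignores the masking and qk-RMSNorm details of the actual module, where the diff passes scale = 1. because the norm already applies the scaling:

    import torch
    import torch.nn.functional as F

    # (batch, heads, seq_len, dim_head)
    q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))

    # manual softmax attention
    dots = torch.matmul(q, k.transpose(-1, -2)) * (64 ** -0.5)
    manual = torch.matmul(dots.softmax(dim=-1), v)

    # fused kernel; default scale is 1/sqrt(dim_head)
    fused = F.scaled_dot_product_attention(q, k, v)

    assert torch.allclose(manual, fused, atol=1e-5)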
lucidrains
3cff5e547a address https://github.com/lucidrains/vit-pytorch/issues/352 2025-12-02 05:21:52 -08:00
lucidrains
fdaf7f92b9 fix positional embed for mean pool case and cleanup 2025-11-27 17:01:47 -08:00
8 changed files with 70 additions and 59 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "vit-pytorch"
-version = "1.16.0"
+version = "1.16.4"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }

View File

@@ -1,6 +1,6 @@
 from __future__ import annotations
-from functools import partial
+from functools import partial, lru_cache
 from typing import List
 import torch
@@ -9,7 +9,6 @@ from torch import nn, Tensor
 from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
 from einops import rearrange, repeat
-from einops.layers.torch import Rearrange
 # helpers
@@ -28,6 +27,12 @@ def pair(t):
 def divisible_by(numer, denom):
     return (numer % denom) == 0
+@lru_cache(maxsize=128)
+def posemb_grid(ph, pw, device):
+    h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
+    w_idx = torch.arange(pw, device=device).repeat(ph)
+    return torch.stack([h_idx, w_idx], dim=-1)
 # auto grouping images
 def group_images_by_max_seq_len(
@@ -117,8 +122,7 @@ class Attention(nn.Module):
         self.q_norm = RMSNorm(heads, dim_head)
         self.k_norm = RMSNorm(heads, dim_head)
-        self.attend = nn.Softmax(dim = -1)
-        self.dropout = nn.Dropout(dropout)
+        self.dropout_p = dropout
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -145,19 +149,22 @@
         q = self.q_norm(q)
         k = self.k_norm(k)
-        dots = torch.matmul(q, k.transpose(-1, -2))
-        if exists(mask):
-            mask = rearrange(mask, 'b j -> b 1 1 j')
-            dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
-        if exists(attn_mask):
-            dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
-        attn = self.attend(dots)
-        attn = self.dropout(attn)
-        out = torch.matmul(attn, v)
+        # combine masks if both exist
+        if exists(mask) or exists(attn_mask):
+            if exists(mask):
+                mask = rearrange(mask, 'b j -> b 1 1 j')
+            if exists(mask) and exists(attn_mask):
+                attn_mask = mask & attn_mask
+            elif exists(mask):
+                attn_mask = mask
+        out = F.scaled_dot_product_attention(
+            q, k, v,
+            attn_mask = attn_mask,
+            dropout_p = self.dropout_p if self.training else 0.,
+            scale = 1. # RMSNorm already includes sqrt(dim) scaling
+        )
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
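
A toy illustration (hypothetical shapes) of the mask combination above: the key-padding mask is broadcast to the attention-mask shape and the two are ANDed into the single boolean attn_mask handed to SDPA:

    import torch

    mask = torch.tensor([[True, True, False]])               # (b, j) key padding mask
    attn_mask = torch.ones(1, 1, 3, 3, dtype=torch.bool)     # (b, heads, i, j)

    combined = mask[:, None, None, :] & attn_mask             # broadcasts to (1, 1, 3, 3)
    # every query row now ignores the padded third key
    assert combined[0, 0].tolist() == [[True, True, False]] * 3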
@@ -281,42 +288,41 @@ class NaViT(nn.Module):
         for images in batched_images:
             num_images.append(len(images))
-            sequences = []
-            positions = []
-            image_ids = torch.empty((0,), device = device, dtype = torch.long)
-            for image_id, image in enumerate(images):
-                assert image.ndim ==3 and image.shape[0] == c
-                image_dims = image.shape[-2:]
-                assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
-                ph, pw = map(lambda dim: dim // p, image_dims)
-                pos = torch.stack(torch.meshgrid((
-                    arange(ph),
-                    arange(pw)
-                ), indexing = 'ij'), dim = -1)
-                pos = rearrange(pos, 'h w c -> (h w) c')
-                seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
-                seq_len = seq.shape[-2]
-                if has_token_dropout:
-                    token_dropout = self.calc_token_dropout(*image_dims)
-                    num_keep = max(1, int(seq_len * (1 - token_dropout)))
-                    keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
-                    seq = seq[keep_indices]
-                    pos = pos[keep_indices]
-                image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
-                sequences.append(seq)
-                positions.append(pos)
+            # compute patch dimensions for all images
+            patch_dims = []
+            for image in images:
+                assert image.ndim == 3 and image.shape[0] == c
+                image_dims = image.shape[-2:]
+                assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
+                patch_dims.append((image_dims[0] // p, image_dims[1] // p))
+            # extract patches for all images
+            sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images]
+            # compute positions - uses lru_cache to avoid redundant computation across forward passes
+            positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims]
+            # handle token dropout
+            if has_token_dropout:
+                for i, (seq, pos) in enumerate(zip(sequences, positions)):
+                    image_dims = images[i].shape[-2:]
+                    token_dropout = self.calc_token_dropout(*image_dims)
+                    seq_len = seq.shape[0]
+                    num_keep = max(1, int(seq_len * (1 - token_dropout)))
+                    keep_indices = torch.randn((seq_len,), device=device).topk(num_keep, dim=-1).indices
+                    sequences[i] = seq[keep_indices]
+                    positions[i] = pos[keep_indices]
+            # build image_ids efficiently using repeat_interleave
+            patch_counts = [seq.shape[0] for seq in sequences]
+            image_ids = torch.repeat_interleave(
+                arange(len(images)),
+                torch.tensor(patch_counts, device=device)
+            )
             batched_image_ids.append(image_ids)
-            batched_sequences.append(torch.cat(sequences, dim = 0))
-            batched_positions.append(torch.cat(positions, dim = 0))
+            batched_sequences.append(torch.cat(sequences, dim=0))
+            batched_positions.append(torch.cat(positions, dim=0))
         # derive key padding mask
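
A toy example (made-up patch counts) of the repeat_interleave construction of image_ids used above, which replaces the old per-image F.pad accumulation:

    import torch

    patch_counts = [4, 6, 2]    # patches contributed by each of three images
    image_ids = torch.repeat_interleave(
        torch.arange(len(patch_counts)),
        torch.tensor(patch_counts)
    )
    # tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2])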

View File

@@ -176,7 +176,7 @@ class NaViT(Module):
         self.channels = channels
         self.patch_size = patch_size
-        self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
+        self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c pf p1 p2)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
         self.to_patch_embedding = nn.Sequential(
             nn.LayerNorm(patch_dim),
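
A toy illustration (tiny made-up volume) of why the flattening order matters: both patterns keep the same values per patch, but order them differently, so the linear projection that follows sees a different layout:

    import torch
    from einops import rearrange

    x = torch.arange(4.).reshape(1, 2, 2, 1)    # (c, f*pf, h*p1, w*p2) with pf = 2, p1 = 2, p2 = 1

    old = rearrange(x, 'c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1=2, p2=1, pf=2)
    new = rearrange(x, 'c (f pf) (h p1) (w p2) -> f h w (c pf p1 p2)', p1=2, p2=1, pf=2)

    print(old.flatten().tolist())   # [0.0, 2.0, 1.0, 3.0] - frame-patch axis varies fastest
    print(new.flatten().tolist())   # [0.0, 1.0, 2.0, 3.0] - spatial patch axes vary fastest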

View File

@@ -146,7 +146,7 @@ class SimpleViT(Module):
         patch_dim = channels * patch_height * patch_width * frame_patch_size
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
+            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim),

View File

@@ -103,7 +103,7 @@ class SimpleViT(nn.Module):
         patch_dim = channels * patch_height * patch_width * frame_patch_size
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
+            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim),

View File

@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+from torch.nn import Module, ModuleList
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
@@ -11,7 +12,7 @@ def pair(t):
 # classes
-class FeedForward(nn.Module):
+class FeedForward(Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
@@ -26,7 +27,7 @@ class FeedForward(nn.Module):
     def forward(self, x):
         return self.net(x)
-class Attention(nn.Module):
+class Attention(Module):
     def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         super().__init__()
         inner_dim = dim_head * heads
@@ -62,13 +63,14 @@ class Attention(nn.Module):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
-class Transformer(nn.Module):
+class Transformer(Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])
         for _ in range(depth):
-            self.layers.append(nn.ModuleList([
+            self.layers.append(ModuleList([
                 Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
@@ -80,7 +82,7 @@ class Transformer(nn.Module):
         return self.norm(x)
-class ViT(nn.Module):
+class ViT(Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
         super().__init__()
         image_height, image_width = pair(image_size)
@@ -101,8 +103,8 @@ class ViT(nn.Module):
             nn.LayerNorm(dim),
         )
-        self.cls_token = nn.Parameter(torch.randn(1, num_cls_tokens, dim))
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + num_cls_tokens, dim))
+        self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
         self.dropout = nn.Dropout(emb_dropout)
@@ -114,12 +116,15 @@ class ViT(nn.Module):
         self.mlp_head = nn.Linear(dim, num_classes)
     def forward(self, img):
+        batch = img.shape[0]
         x = self.to_patch_embedding(img)
-        b, n, _ = x.shape
-        cls_tokens = repeat(self.cls_token, '1 ... d -> b ... d', b = b)
-        x = torch.cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding[:, :(n + 1)]
+        cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
+        x = torch.cat((cls_tokens, x), dim = 1)
+        seq = x.shape[1]
+        x = x + self.pos_embedding[:seq]
         x = self.dropout(x)
         x = self.transformer(x)
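
A toy check (hypothetical sizes) of the reshaped cls_token above: the parameter now has shape (num_cls_tokens, dim) and einops repeat adds the batch axis:

    import torch
    from einops import repeat

    cls_token = torch.randn(1, 4)                        # (num_cls_tokens, dim), no leading singleton batch dim
    cls_tokens = repeat(cls_token, '... d -> b ... d', b=2)
    print(cls_tokens.shape)                               # torch.Size([2, 1, 4])

The positional embedding is sliced the same way, as pos_embedding[:seq], and broadcasts over the batch dimension.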

View File

@@ -89,7 +89,7 @@ class ViT(nn.Module):
         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
+            Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim),

View File

@@ -141,7 +141,7 @@ class ViT(nn.Module):
         self.global_average_pool = pool == 'mean'
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
+            Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim)