optimize NaViT with SDPA and vectorized forward pass (#353)

- Replace manual attention with F.scaled_dot_product_attention
- Use repeat_interleave instead of meshgrid for position computation
- Build image_ids efficiently with repeat_interleave instead of F.pad
- Remove unused Rearrange import

~56% speedup (91ms -> 58ms on 512 variable-sized images)
Numerically equivalent (max diff ~5e-4, within flash attention tolerance)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Amit Moryossef
2025-12-06 13:56:40 +01:00
committed by GitHub
parent 3cff5e547a
commit a1ee1daa1a

View File

@@ -9,7 +9,6 @@ from torch import nn, Tensor
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
@@ -117,8 +116,7 @@ class Attention(nn.Module):
self.q_norm = RMSNorm(heads, dim_head)
self.k_norm = RMSNorm(heads, dim_head)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.dropout_p = dropout
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -145,19 +143,22 @@ class Attention(nn.Module):
q = self.q_norm(q)
k = self.k_norm(k)
dots = torch.matmul(q, k.transpose(-1, -2))
# combine masks if both exist
if exists(mask) or exists(attn_mask):
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
if exists(mask) and exists(attn_mask):
attn_mask = mask & attn_mask
elif exists(mask):
attn_mask = mask
if exists(attn_mask):
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = attn_mask,
dropout_p = self.dropout_p if self.training else 0.,
scale = 1. # RMSNorm already includes sqrt(dim) scaling
)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -281,38 +282,41 @@ class NaViT(nn.Module):
for images in batched_images:
num_images.append(len(images))
sequences = []
positions = []
image_ids = torch.empty((0,), device = device, dtype = torch.long)
for image_id, image in enumerate(images):
# compute patch dimensions for all images
patch_dims = []
for image in images:
assert image.ndim == 3 and image.shape[0] == c
image_dims = image.shape[-2:]
assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
patch_dims.append((image_dims[0] // p, image_dims[1] // p))
ph, pw = map(lambda dim: dim // p, image_dims)
# extract patches for all images
sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images]
pos = torch.stack(torch.meshgrid((
arange(ph),
arange(pw)
), indexing = 'ij'), dim = -1)
pos = rearrange(pos, 'h w c -> (h w) c')
seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
seq_len = seq.shape[-2]
# compute positions using repeat_interleave (faster than meshgrid per image)
positions = []
for ph, pw in patch_dims:
h_idx = arange(ph).repeat_interleave(pw)
w_idx = arange(pw).repeat(ph)
positions.append(torch.stack([h_idx, w_idx], dim=-1))
# handle token dropout
if has_token_dropout:
for i, (seq, pos) in enumerate(zip(sequences, positions)):
image_dims = images[i].shape[-2:]
token_dropout = self.calc_token_dropout(*image_dims)
seq_len = seq.shape[0]
num_keep = max(1, int(seq_len * (1 - token_dropout)))
keep_indices = torch.randn((seq_len,), device=device).topk(num_keep, dim=-1).indices
sequences[i] = seq[keep_indices]
positions[i] = pos[keep_indices]
seq = seq[keep_indices]
pos = pos[keep_indices]
image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
sequences.append(seq)
positions.append(pos)
# build image_ids efficiently using repeat_interleave
patch_counts = [seq.shape[0] for seq in sequences]
image_ids = torch.repeat_interleave(
arange(len(images)),
torch.tensor(patch_counts, device=device)
)
batched_image_ids.append(image_ids)
batched_sequences.append(torch.cat(sequences, dim=0))