From a1ee1daa1a75d75706c271704dc17c11e4a5c8f6 Mon Sep 17 00:00:00 2001
From: Amit Moryossef
Date: Sat, 6 Dec 2025 13:56:40 +0100
Subject: [PATCH] optimize NaViT with SDPA and vectorized forward pass (#353)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Replace manual attention with F.scaled_dot_product_attention
- Use repeat_interleave instead of meshgrid for position computation
- Build image_ids efficiently with repeat_interleave instead of F.pad
- Remove unused Rearrange import

~56% speedup (91ms -> 58ms on 512 variable-sized images)
Numerically equivalent (max diff ~5e-4, within flash attention tolerance)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude
---
 vit_pytorch/na_vit.py | 84 ++++++++++++++++++++++---------------------
 1 file changed, 44 insertions(+), 40 deletions(-)

diff --git a/vit_pytorch/na_vit.py b/vit_pytorch/na_vit.py
index 1aec56a..4850d56 100644
--- a/vit_pytorch/na_vit.py
+++ b/vit_pytorch/na_vit.py
@@ -9,7 +9,6 @@ from torch import nn, Tensor
 from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
 
 from einops import rearrange, repeat
-from einops.layers.torch import Rearrange
 
 # helpers
 
@@ -117,8 +116,7 @@ class Attention(nn.Module):
         self.q_norm = RMSNorm(heads, dim_head)
         self.k_norm = RMSNorm(heads, dim_head)
 
-        self.attend = nn.Softmax(dim = -1)
-        self.dropout = nn.Dropout(dropout)
+        self.dropout_p = dropout
 
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -145,19 +143,22 @@ class Attention(nn.Module):
         q = self.q_norm(q)
         k = self.k_norm(k)
 
-        dots = torch.matmul(q, k.transpose(-1, -2))
+        # combine masks if both exist
+        if exists(mask) or exists(attn_mask):
+            if exists(mask):
+                mask = rearrange(mask, 'b j -> b 1 1 j')
+            if exists(mask) and exists(attn_mask):
+                attn_mask = mask & attn_mask
+            elif exists(mask):
+                attn_mask = mask
 
-        if exists(mask):
-            mask = rearrange(mask, 'b j -> b 1 1 j')
-            dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
+        out = F.scaled_dot_product_attention(
+            q, k, v,
+            attn_mask = attn_mask,
+            dropout_p = self.dropout_p if self.training else 0.,
+            scale = 1.  # RMSNorm already includes sqrt(dim) scaling
+        )
 
-        if exists(attn_mask):
-            dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
-
-        attn = self.attend(dots)
-        attn = self.dropout(attn)
-
-        out = torch.matmul(attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
@@ -281,42 +282,45 @@ class NaViT(nn.Module):
         for images in batched_images:
             num_images.append(len(images))
 
-            sequences = []
-            positions = []
-            image_ids = torch.empty((0,), device = device, dtype = torch.long)
-
-            for image_id, image in enumerate(images):
-                assert image.ndim ==3 and image.shape[0] == c
+            # compute patch dimensions for all images
+            patch_dims = []
+            for image in images:
+                assert image.ndim == 3 and image.shape[0] == c
                 image_dims = image.shape[-2:]
                 assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
+                patch_dims.append((image_dims[0] // p, image_dims[1] // p))
 
-                ph, pw = map(lambda dim: dim // p, image_dims)
+            # extract patches for all images
+            sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images]
 
-                pos = torch.stack(torch.meshgrid((
-                    arange(ph),
-                    arange(pw)
-                ), indexing = 'ij'), dim = -1)
+            # compute positions using repeat_interleave (faster than meshgrid per image)
+            positions = []
+            for ph, pw in patch_dims:
+                h_idx = arange(ph).repeat_interleave(pw)
+                w_idx = arange(pw).repeat(ph)
+                positions.append(torch.stack([h_idx, w_idx], dim=-1))
 
-                pos = rearrange(pos, 'h w c -> (h w) c')
-                seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
-
-                seq_len = seq.shape[-2]
-
-                if has_token_dropout:
+            # handle token dropout
+            if has_token_dropout:
+                for i, (seq, pos) in enumerate(zip(sequences, positions)):
+                    image_dims = images[i].shape[-2:]
                     token_dropout = self.calc_token_dropout(*image_dims)
+                    seq_len = seq.shape[0]
                     num_keep = max(1, int(seq_len * (1 - token_dropout)))
-                    keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
+                    keep_indices = torch.randn((seq_len,), device=device).topk(num_keep, dim=-1).indices
+                    sequences[i] = seq[keep_indices]
+                    positions[i] = pos[keep_indices]
 
-                    seq = seq[keep_indices]
-                    pos = pos[keep_indices]
-
-                image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
-                sequences.append(seq)
-                positions.append(pos)
+            # build image_ids efficiently using repeat_interleave
+            patch_counts = [seq.shape[0] for seq in sequences]
+            image_ids = torch.repeat_interleave(
+                arange(len(images)),
+                torch.tensor(patch_counts, device=device)
+            )
 
             batched_image_ids.append(image_ids)
-            batched_sequences.append(torch.cat(sequences, dim = 0))
-            batched_positions.append(torch.cat(positions, dim = 0))
+            batched_sequences.append(torch.cat(sequences, dim=0))
+            batched_positions.append(torch.cat(positions, dim=0))
 
         # derive key padding mask
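The numerical-equivalence claim in the commit message can be checked in isolation. Below is a minimal standalone sketch (not part of the patch; shapes are arbitrary) comparing the removed unscaled matmul/softmax path against `F.scaled_dot_product_attention` with `scale = 1.`. The `scale` keyword requires PyTorch >= 2.1.

```python
# Sketch only: verify SDPA with scale = 1. reproduces the old unscaled path.
import torch
import torch.nn.functional as F

b, h, n, d = 2, 8, 64, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
mask = torch.rand(b, 1, 1, n) > 0.1  # key padding mask, broadcast over heads and queries

# old path: no 1/sqrt(d) scaling, since q/k come out of RMSNorm pre-scaled
dots = torch.matmul(q, k.transpose(-1, -2))
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
out_old = torch.matmul(dots.softmax(dim = -1), v)

# new path: fused kernel, scale = 1. to match the unscaled dots
out_new = F.scaled_dot_product_attention(q, k, v, attn_mask = mask, scale = 1.)

print((out_old - out_new).abs().max())  # tiny on CPU; larger (~5e-4) under flash kernels
```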
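The position rewrite relies on `repeat_interleave`/`repeat` reproducing a row-major `meshgrid` exactly. A small sanity check (again not part of the patch, with arbitrary patch counts):

```python
# Sketch only: repeat_interleave construction matches meshgrid(indexing='ij') flattened row-major.
import torch

ph, pw = 3, 5  # patches along height / width

grid = torch.stack(torch.meshgrid(torch.arange(ph), torch.arange(pw), indexing = 'ij'), dim = -1)
old = grid.reshape(-1, 2)                       # (ph * pw, 2), row-major

h_idx = torch.arange(ph).repeat_interleave(pw)  # 0,0,0,0,0, 1,1,1,1,1, 2,...
w_idx = torch.arange(pw).repeat(ph)             # 0,1,2,3,4, 0,1,2,3,4, 0,...
new = torch.stack([h_idx, w_idx], dim = -1)

assert torch.equal(old, new)
```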
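Likewise for `image_ids`: growing a tensor with `F.pad` once per image copies the whole prefix on every iteration, while a single `repeat_interleave` over the per-image patch counts is one vectorized expansion. A toy comparison with hypothetical counts:

```python
# Sketch only: one repeat_interleave call replaces the per-image F.pad loop.
import torch
import torch.nn.functional as F

patch_counts = [4, 2, 3]  # hypothetical patches kept per image after token dropout

# old path: repeated F.pad, quadratic total copies
old_ids = torch.empty((0,), dtype = torch.long)
for image_id, count in enumerate(patch_counts):
    old_ids = F.pad(old_ids, (0, count), value = image_id)

# new path: single vectorized expansion
new_ids = torch.repeat_interleave(torch.arange(len(patch_counts)), torch.tensor(patch_counts))

assert torch.equal(old_ids, new_ids)  # tensor([0, 0, 0, 0, 1, 1, 2, 2, 2])
```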