From d518e89573174d471e1c3ea988e0c69f6e9f6863 Mon Sep 17 00:00:00 2001 From: Amit Moryossef Date: Sun, 7 Dec 2025 13:32:30 +0100 Subject: [PATCH] cache position grids in NaViT forward pass (#354) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use lru_cache to cache unique (ph, pw, device) position grids, avoiding redundant computation when multiple images share the same patch dimensions. Cache persists across forward passes. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude Opus 4.5 --- vit_pytorch/na_vit.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vit_pytorch/na_vit.py b/vit_pytorch/na_vit.py index 4850d56..0d3d262 100644 --- a/vit_pytorch/na_vit.py +++ b/vit_pytorch/na_vit.py @@ -1,6 +1,6 @@ from __future__ import annotations -from functools import partial +from functools import partial, lru_cache from typing import List import torch @@ -27,6 +27,12 @@ def pair(t): def divisible_by(numer, denom): return (numer % denom) == 0 +@lru_cache(maxsize=128) +def posemb_grid(ph, pw, device): + h_idx = torch.arange(ph, device=device).repeat_interleave(pw) + w_idx = torch.arange(pw, device=device).repeat(ph) + return torch.stack([h_idx, w_idx], dim=-1) + # auto grouping images def group_images_by_max_seq_len( @@ -293,12 +299,8 @@ class NaViT(nn.Module): # extract patches for all images sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images] - # compute positions using repeat_interleave (faster than meshgrid per image) - positions = [] - for ph, pw in patch_dims: - h_idx = arange(ph).repeat_interleave(pw) - w_idx = arange(pw).repeat(ph) - positions.append(torch.stack([h_idx, w_idx], dim=-1)) + # compute positions - uses lru_cache to avoid redundant computation across forward passes + positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims] # handle token dropout if has_token_dropout: