cache position grids in NaViT forward pass (#354)

Use lru_cache to cache unique (ph, pw, device) position grids, avoiding
redundant computation when multiple images share the same patch
dimensions. Cache persists across forward passes.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>

Author:    Amit Moryossef
Date:      2025-12-07 13:32:30 +01:00
Committed: GitHub
Parent:    dd6462d19b
Commit:    d518e89573


@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from functools import partial
+from functools import partial, lru_cache
 from typing import List
 
 import torch
@@ -27,6 +27,12 @@ def pair(t):
 def divisible_by(numer, denom):
     return (numer % denom) == 0
 
+@lru_cache(maxsize=128)
+def posemb_grid(ph, pw, device):
+    h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
+    w_idx = torch.arange(pw, device=device).repeat(ph)
+    return torch.stack([h_idx, w_idx], dim=-1)
+
 # auto grouping images
 
 def group_images_by_max_seq_len(
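
For illustration, a minimal standalone sketch (not part of the diff) of how the cached helper behaves; the `posemb_grid` definition is copied from the hunk above, and the grid sizes are made-up examples. Repeated calls with the same `(ph, pw, device)` key return the same cached tensor, which is what lets the grids persist across forward passes.

```python
from functools import lru_cache

import torch


@lru_cache(maxsize=128)
def posemb_grid(ph, pw, device):
    # (ph * pw, 2) grid of (row, col) patch coordinates
    h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
    w_idx = torch.arange(pw, device=device).repeat(ph)
    return torch.stack([h_idx, w_idx], dim=-1)


device = torch.device('cpu')

a = posemb_grid(2, 3, device)   # miss: grid is computed
b = posemb_grid(2, 3, device)   # hit: same key, cached result returned
c = posemb_grid(3, 2, device)   # miss: different (ph, pw)

assert a is b                    # the cached call returns the very same tensor object
assert a.shape == (6, 2)         # 2 x 3 patches -> 6 (row, col) pairs
print(posemb_grid.cache_info())  # CacheInfo(hits=1, misses=2, maxsize=128, currsize=2)
```

`torch.device` is hashable, so the `(ph, pw, device)` tuple is a valid `lru_cache` key; the trade-off is that the cached tensor is shared by reference, so downstream code should treat it as read-only.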
@@ -293,12 +299,8 @@ class NaViT(nn.Module):
         # extract patches for all images
         sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images]
 
-        # compute positions using repeat_interleave (faster than meshgrid per image)
-        positions = []
-        for ph, pw in patch_dims:
-            h_idx = arange(ph).repeat_interleave(pw)
-            w_idx = arange(pw).repeat(ph)
-            positions.append(torch.stack([h_idx, w_idx], dim=-1))
+        # compute positions - uses lru_cache to avoid redundant computation across forward passes
+        positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims]
 
         # handle token dropout
         if has_token_dropout:
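
As a quick sanity check (not part of the diff), a sketch comparing the cached helper against the per-image loop it replaces; `patch_dims` here is a made-up example, and `posemb_grid` is assumed to be in scope (from the patched module or the sketch above).

```python
import torch

# assumes posemb_grid (the lru_cache-decorated helper above) is already defined
device = torch.device('cpu')
patch_dims = [(2, 3), (2, 3), (4, 4)]  # hypothetical per-image (ph, pw) patch counts

# old behaviour: rebuild the grid for every image
old_positions = []
for ph, pw in patch_dims:
    h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
    w_idx = torch.arange(pw, device=device).repeat(ph)
    old_positions.append(torch.stack([h_idx, w_idx], dim=-1))

# new behaviour: look up (or compute once) the cached grid per (ph, pw, device) key
new_positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims]

# same values as before...
assert all(torch.equal(old, new) for old, new in zip(old_positions, new_positions))
# ...but the two (2, 3) images now share a single cached tensor
assert new_positions[0] is new_positions[1]
```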