diff --git a/vit_pytorch/na_vit.py b/vit_pytorch/na_vit.py index 4850d56..0d3d262 100644 --- a/vit_pytorch/na_vit.py +++ b/vit_pytorch/na_vit.py @@ -1,6 +1,6 @@ from __future__ import annotations -from functools import partial +from functools import partial, lru_cache from typing import List import torch @@ -27,6 +27,12 @@ def pair(t): def divisible_by(numer, denom): return (numer % denom) == 0 +@lru_cache(maxsize=128) +def posemb_grid(ph, pw, device): + h_idx = torch.arange(ph, device=device).repeat_interleave(pw) + w_idx = torch.arange(pw, device=device).repeat(ph) + return torch.stack([h_idx, w_idx], dim=-1) + # auto grouping images def group_images_by_max_seq_len( @@ -293,12 +299,8 @@ class NaViT(nn.Module): # extract patches for all images sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images] - # compute positions using repeat_interleave (faster than meshgrid per image) - positions = [] - for ph, pw in patch_dims: - h_idx = arange(ph).repeat_interleave(pw) - w_idx = arange(pw).repeat(ph) - positions.append(torch.stack([h_idx, w_idx], dim=-1)) + # compute positions - uses lru_cache to avoid redundant computation across forward passes + positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims] # handle token dropout if has_token_dropout: