Commit Graph

8 Commits

Author SHA1 Message Date
Amit Moryossef
d518e89573 cache position grids in NaViT forward pass (#354)
Use lru_cache to cache unique (ph, pw, device) position grids, avoiding
redundant computation when multiple images share the same patch
dimensions. Cache persists across forward passes.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-07 04:32:30 -08:00
Amit Moryossef
a1ee1daa1a optimize NaViT with SDPA and vectorized forward pass (#353)
- Replace manual attention with F.scaled_dot_product_attention
- Use repeat_interleave instead of meshgrid for position computation
- Build image_ids efficiently with repeat_interleave instead of F.pad
- Remove unused Rearrange import

~56% speedup (91ms -> 58ms on 512 variable-sized images)
Numerically equivalent (max diff ~5e-4, within flash attention tolerance)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-06 04:56:40 -08:00
lucidrains
73199ab486 Nested navit (#325)
add a variant of NaViT using nested tensors
2024-08-20 15:12:29 -07:00
Phil Wang
96f66d2754 address https://github.com/lucidrains/vit-pytorch/issues/306 2024-04-18 09:44:29 -07:00
Phil Wang
6e2393de95 wrap up NaViT 2023-07-25 10:38:55 -07:00
Phil Wang
32974c33df one can pass a callback to token_dropout_prob for NaViT that takes in height and width and calculate appropriate dropout rate 2023-07-24 14:52:40 -07:00
Phil Wang
17675e0de4 add constant token dropout for NaViT 2023-07-24 14:14:36 -07:00
Phil Wang
23820bc54a begin work on NaViT (#273)
finish core idea of NaViT
2023-07-24 13:54:02 -07:00