Compare commits

...

11 Commits

Author SHA1 Message Date
lucidrains
7e703f239f get a version of n-dimensional vit with golden gate polar coordinate embeddings into the repo for future use 2025-12-25 06:57:56 -08:00
Phil Wang
0b7518ef45 educate 2025-12-21 07:06:20 -08:00
lucidrains
077d8c188f fix distill 2025-12-10 15:52:10 -08:00
lucidrains
5888f05300 1.16.4 2025-12-07 04:32:52 -08:00
Amit Moryossef
d518e89573 cache position grids in NaViT forward pass (#354)
Use lru_cache to cache unique (ph, pw, device) position grids, avoiding
redundant computation when multiple images share the same patch
dimensions. Cache persists across forward passes.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-07 04:32:30 -08:00
lucidrains
dd6462d19b release small navit perf 2025-12-06 04:57:12 -08:00
Amit Moryossef
a1ee1daa1a optimize NaViT with SDPA and vectorized forward pass (#353)
- Replace manual attention with F.scaled_dot_product_attention
- Use repeat_interleave instead of meshgrid for position computation
- Build image_ids efficiently with repeat_interleave instead of F.pad
- Remove unused Rearrange import

~56% speedup (91ms -> 58ms on 512 variable-sized images)
Numerically equivalent (max diff ~5e-4, within flash attention tolerance)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-06 04:56:40 -08:00
lucidrains
3cff5e547a address https://github.com/lucidrains/vit-pytorch/issues/352 2025-12-02 05:21:52 -08:00
lucidrains
fdaf7f92b9 fix positional embed for mean pool case and cleanup 2025-11-27 17:01:47 -08:00
lucidrains
0ebd4edab9 address https://github.com/lucidrains/vit-pytorch/issues/351 2025-11-27 06:07:43 -08:00
lucidrains
aa49c2783a VAAT should have two ears 2025-11-22 08:32:23 -08:00
12 changed files with 494 additions and 83 deletions

View File

@@ -49,7 +49,7 @@
## Vision Transformer - Pytorch
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the [attention](https://www.youtube.com/watch?v=eMlx5fFNoYc) revolution.
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>.
@@ -2213,4 +2213,16 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{gopalakrishnan2025decouplingwhatwherepolar,
title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
year = {2025},
eprint = {2509.10534},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2509.10534},
}
```
*I visualise a time when we will be to robots what dogs are to humans, and Im rooting for the machines.* — Claude Shannon

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.15.6"
version = "1.17.1"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }

View File

@@ -25,12 +25,12 @@ class DistillMixin:
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
cls_tokens = repeat(self.cls_token, 'n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
x += self.pos_embedding[:(n + 1)]
if distilling:
distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
distill_tokens = repeat(distill_token, 'n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
x = self._attend(x)
@@ -125,7 +125,7 @@ class DistillWrapper(Module):
self.alpha = alpha
self.hard = hard
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
self.distillation_token = nn.Parameter(torch.randn(1, dim))
self.distill_mlp = nn.Sequential(
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from functools import partial
from functools import partial, lru_cache
from typing import List
import torch
@@ -9,7 +9,6 @@ from torch import nn, Tensor
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
@@ -28,6 +27,12 @@ def pair(t):
def divisible_by(numer, denom):
return (numer % denom) == 0
@lru_cache(maxsize=128)
def posemb_grid(ph, pw, device):
h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
w_idx = torch.arange(pw, device=device).repeat(ph)
return torch.stack([h_idx, w_idx], dim=-1)
# auto grouping images
def group_images_by_max_seq_len(
@@ -117,8 +122,7 @@ class Attention(nn.Module):
self.q_norm = RMSNorm(heads, dim_head)
self.k_norm = RMSNorm(heads, dim_head)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.dropout_p = dropout
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -145,19 +149,22 @@ class Attention(nn.Module):
q = self.q_norm(q)
k = self.k_norm(k)
dots = torch.matmul(q, k.transpose(-1, -2))
# combine masks if both exist
if exists(mask) or exists(attn_mask):
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
if exists(mask) and exists(attn_mask):
attn_mask = mask & attn_mask
elif exists(mask):
attn_mask = mask
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = attn_mask,
dropout_p = self.dropout_p if self.training else 0.,
scale = 1. # RMSNorm already includes sqrt(dim) scaling
)
if exists(attn_mask):
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -281,42 +288,41 @@ class NaViT(nn.Module):
for images in batched_images:
num_images.append(len(images))
sequences = []
positions = []
image_ids = torch.empty((0,), device = device, dtype = torch.long)
for image_id, image in enumerate(images):
assert image.ndim ==3 and image.shape[0] == c
# compute patch dimensions for all images
patch_dims = []
for image in images:
assert image.ndim == 3 and image.shape[0] == c
image_dims = image.shape[-2:]
assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
patch_dims.append((image_dims[0] // p, image_dims[1] // p))
ph, pw = map(lambda dim: dim // p, image_dims)
# extract patches for all images
sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images]
pos = torch.stack(torch.meshgrid((
arange(ph),
arange(pw)
), indexing = 'ij'), dim = -1)
# compute positions - uses lru_cache to avoid redundant computation across forward passes
positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims]
pos = rearrange(pos, 'h w c -> (h w) c')
seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
seq_len = seq.shape[-2]
if has_token_dropout:
# handle token dropout
if has_token_dropout:
for i, (seq, pos) in enumerate(zip(sequences, positions)):
image_dims = images[i].shape[-2:]
token_dropout = self.calc_token_dropout(*image_dims)
seq_len = seq.shape[0]
num_keep = max(1, int(seq_len * (1 - token_dropout)))
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
keep_indices = torch.randn((seq_len,), device=device).topk(num_keep, dim=-1).indices
sequences[i] = seq[keep_indices]
positions[i] = pos[keep_indices]
seq = seq[keep_indices]
pos = pos[keep_indices]
image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
sequences.append(seq)
positions.append(pos)
# build image_ids efficiently using repeat_interleave
patch_counts = [seq.shape[0] for seq in sequences]
image_ids = torch.repeat_interleave(
arange(len(images)),
torch.tensor(patch_counts, device=device)
)
batched_image_ids.append(image_ids)
batched_sequences.append(torch.cat(sequences, dim = 0))
batched_positions.append(torch.cat(positions, dim = 0))
batched_sequences.append(torch.cat(sequences, dim=0))
batched_positions.append(torch.cat(positions, dim=0))
# derive key padding mask

View File

@@ -176,7 +176,7 @@ class NaViT(Module):
self.channels = channels
self.patch_size = patch_size
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c pf p1 p2)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
self.to_patch_embedding = nn.Sequential(
nn.LayerNorm(patch_dim),

View File

@@ -146,7 +146,7 @@ class SimpleViT(Module):
patch_dim = channels * patch_height * patch_width * frame_patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),

View File

@@ -103,7 +103,7 @@ class SimpleViT(nn.Module):
patch_dim = channels * patch_height * patch_width * frame_patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),

View File

@@ -421,7 +421,8 @@ class VAAT(Module):
dim_head,
dim_action,
mlp_dim,
num_views = None,
num_image_views = None,
num_audio_views = None,
num_tasks = None,
dim_extra_token = None,
num_register_tokens = 4,
@@ -462,6 +463,8 @@ class VAAT(Module):
ast_dim = ast.dim
self.ast_accept_spec = ast.accept_spec
assert ast.depth == depth or exists(ast_layer_indices), f'if the VAAT depth is not equal to the AST depth, you must pass in the indices from the AST to be layered to the VAAT in order from bottom to top'
ast_layer_indices = default(ast_layer_indices, tuple(range(depth)))
@@ -480,7 +483,9 @@ class VAAT(Module):
# maybe view embeddings
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
self.image_view_emb = nn.Parameter(torch.randn(num_image_views, vit_dim) * 1e-2) if exists(num_image_views) and num_image_views > 1 else None
self.audio_view_emb = nn.Parameter(torch.randn(num_audio_views, ast_dim) * 1e-2) if exists(num_audio_views) and num_audio_views > 1 else None
# handle maybe task conditioning
@@ -523,12 +528,12 @@ class VAAT(Module):
def forward(
self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
audio_or_spec, # (b t) | (b f t) - batch, audio len | batch, spec freq, time
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
audio_or_spec, # (b v? t) | (b v?f t) - batch, audio len | batch, spec freq, time
*,
extra = None, # (b d) - batch, dim extra
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False,
freeze_ast = False
@@ -551,11 +556,26 @@ class VAAT(Module):
assert video_or_image.shape[3] == self.time_seq_len
# audio shapes - adding view if impliciy to be 1
if audio_or_spec.ndim == 2 and not self.ast_accept_spec:
audio_or_spec = rearrange(audio_or_spec, 'b t -> b 1 t')
elif audio_or_spec.ndim == 3 and self.ast_accept_spec:
audio_or_spec = rearrange(audio_or_spec, 'b f t -> b 1 f t')
# to images
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
images, packed_shape = pack([images], '* c h w')
images, image_packed_shape = pack([images], '* c h w')
# to audio
if self.ast_accept_spec:
audio_or_spec, audio_packed_shape = pack([audio_or_spec], '* f t')
else:
audio_or_spec, audio_packed_shape = pack([audio_or_spec], '* t')
# get representation trajectory from vit
@@ -570,9 +590,9 @@ class VAAT(Module):
hiddens = hiddens[self.vit_layer_indices]
# pack temporarily for embedding
# unpack temporarily for embedding
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers
hiddens, = unpack(hiddens, image_packed_shape, 'l * n d') # l for layers
# maybe add time embeddings
@@ -582,11 +602,11 @@ class VAAT(Module):
# maybe view embeddings
if exists(self.view_emb):
assert self.view_emb.shape[0] == hiddens.shape[2]
if exists(self.image_view_emb):
assert self.image_view_emb.shape[0] == hiddens.shape[2]
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
hiddens = hiddens + view_emb
image_view_emb = rearrange(self.image_view_emb, 'v d -> v 1 1 d')
hiddens = hiddens + image_view_emb
# get representation trajectory from ast
@@ -601,6 +621,18 @@ class VAAT(Module):
audio_hiddens = audio_hiddens[self.ast_layer_indices]
# unpack audio temporarily for embedding
audio_hiddens, = unpack(audio_hiddens, audio_packed_shape, 'l * n d') # l for layers
# maybe audio view embeddings
if exists(self.audio_view_emb):
assert self.audio_view_emb.shape[0] == audio_hiddens.shape[2]
audio_view_emb = rearrange(self.audio_view_emb, 'v d -> v 1 1 d')
audio_hiddens = audio_hiddens + audio_view_emb
# maybe tasks
if exists(tasks):
@@ -612,7 +644,7 @@ class VAAT(Module):
image_context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
audio_context = audio_hiddens # eventually handle views (stereo and beyond)
audio_context = rearrange(audio_hiddens, 'l b v n d -> l b (v n) d')
# get main action tokens and maybe append extra
@@ -703,7 +735,7 @@ if __name__ == '__main__':
mlp_dim = 384 * 4
)
vat = VAAT(
vaat = VAAT(
vit,
ast,
dim = 512,
@@ -714,7 +746,8 @@ if __name__ == '__main__':
dim_action = 20,
action_chunk_len = 7,
time_seq_len = 4,
num_views = 2,
num_image_views = 2,
num_audio_views = 2,
num_tasks = 4,
add_self_attn = True,
dim_extra_token = 33, # extra token with some variable dimension
@@ -727,18 +760,18 @@ if __name__ == '__main__':
)
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
audio = torch.randn(2, 14_100 * 5)
audio = torch.randn(2, 2, 14_100 * 5)
tasks = torch.randint(0, 4, (2,))
extra = torch.randn(2, 33) # extra internal state
actions = torch.randn(2, 7, 20) # actions for learning
loss = vat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
loss = vaat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
# after much training
pred_actions, hiddens = vat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
pred_actions, hiddens = vaat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
assert pred_actions.shape == (2, 7, 20)

View File

@@ -1,5 +1,6 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
@@ -11,7 +12,7 @@ def pair(t):
# classes
class FeedForward(nn.Module):
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
@@ -26,7 +27,7 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
@@ -62,13 +63,14 @@ class Attention(nn.Module):
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
@@ -80,7 +82,7 @@ class Transformer(nn.Module):
return self.norm(x)
class ViT(nn.Module):
class ViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
@@ -90,7 +92,9 @@ class ViT(nn.Module):
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_cls_tokens = 1 if pool == 'cls' else 0
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
@@ -99,8 +103,9 @@ class ViT(nn.Module):
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
@@ -111,12 +116,15 @@ class ViT(nn.Module):
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, img):
batch = img.shape[0]
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
x = torch.cat((cls_tokens, x), dim = 1)
seq = x.shape[1]
x = x + self.pos_embedding[:seq]
x = self.dropout(x)
x = self.transformer(x)

View File

@@ -89,7 +89,7 @@ class ViT(nn.Module):
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),

352
vit_pytorch/vit_nd_pope.py Normal file
View File

@@ -0,0 +1,352 @@
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import pi, nn, arange, cat, stack, Tensor
from torch.nn import Module, ModuleList
from torch.amp import autocast
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
def join(arr, delimiter = ' '):
return delimiter.join(arr)
def ensure_tuple(t, length):
if isinstance(t, (tuple, list)):
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
return tuple(t)
return (t,) * length
# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/
# but using polar version instead
# Gopalakrishnan et al. https://arxiv.org/abs/2509.10534
def _phi(m: int) -> float:
x = 2.0
for _ in range(10):
x = (1 + x) ** (1.0 / (m + 1.0))
return x
def make_directions(n: int, d: int) -> Tensor:
g = _phi(d)
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
z = torch.fmod(i * alpha, 1.0)
directions = torch.erfinv(2.0 * z - 1.0)
directions = l2norm(directions)
return directions.float()
class GoldenGatePoPENd(Module):
def __init__(
self,
dim_pos: int,
heads: int,
dim_head: int,
min_freq: float = 1.0,
max_freq: float = 10000.0,
p_zero_freqs: float = 0.0, # proportion of frequencies set to 0
init_learned_bias_uniform = False
):
super().__init__()
n_freqs = dim_head
n_zero_freqs = round(p_zero_freqs * n_freqs)
omega = cat((
torch.zeros(n_zero_freqs),
min_freq * (max_freq / min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
))
directions = rearrange(
make_directions(heads * n_freqs, dim_pos),
'(h f) p -> h f p',
h = heads
)
omega_expanded = rearrange(omega, 'f -> f 1')
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
self.learned_bias = nn.Parameter(torch.zeros(heads, dim_head))
if init_learned_bias_uniform:
self.learned_bias.uniform_(-2. * pi, 0.)
@autocast('cuda', enabled = False)
def forward(self, pos):
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
# compute theta for each (batch, head, seq, freq)
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
bias = self.learned_bias.clamp(-2. * pi, 0.)
bias = rearrange(bias, 'h d -> h 1 d')
return theta, bias
@autocast('cuda', enabled = False)
def apply_polar_pos_emb(t, freqs):
orig_dtype = t.dtype
t = t.float()
t = F.softplus(t)
out = cat((t * freqs.cos(), t * freqs.sin()), dim = -1)
return out.type(orig_dtype)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, polar_pos_emb = None):
x = self.norm(x)
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
if exists(polar_pos_emb):
freqs, bias = polar_pos_emb
q = apply_polar_pos_emb(q, freqs)
k = apply_polar_pos_emb(k, freqs + bias)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., polar_emb = None):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.polar_emb = polar_emb
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, pos = None):
# pope embedding
polar_pos_emb = None
if exists(pos) and exists(self.polar_emb):
polar_pos_emb = self.polar_emb(pos)
# transformer layers
for attn, ff in self.layers:
x = attn(x, polar_pos_emb) + x
x = ff(x) + x
return self.norm(x)
class ViTND(Module):
def __init__(
self,
*,
ndim: int,
input_shape: int | tuple[int, ...],
patch_size: int | tuple[int, ...],
num_classes: int,
dim: int,
depth: int,
heads: int,
mlp_dim: int,
channels: int = 3,
dim_head: int = 64,
dropout: float = 0.,
emb_dropout: float = 0.,
pope_min_freq: float = 1.0,
pope_max_freq: float = 10000.0,
pope_p_zero_freqs: float = 0.0,
pope_init_learned_bias_uniform = False
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
self.ndim = ndim
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.dropout = nn.Dropout(emb_dropout)
# golden gate pope
self.polar_emb = GoldenGatePoPENd(
dim_pos = ndim,
heads = heads,
dim_head = dim_head,
min_freq = pope_min_freq,
max_freq = pope_max_freq,
p_zero_freqs = pope_p_zero_freqs,
init_learned_bias_uniform = pope_init_learned_bias_uniform
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, polar_emb = self.polar_emb)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def muon_parameters(self):
params = []
for m in self.modules():
if isinstance(m, Attention):
params.extend([
m.to_v.weight,
m.to_out[0].weight
])
elif isinstance(m, FeedForward):
params.extend([
m.net[1].weight,
m.net[-2].weight
])
return params
def forward(
self,
x,
return_embed = False
):
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
batch, *spatial_dims, _, device = *x.shape, x.device
# Generate position coordinates
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
grid = torch.meshgrid(*grids, indexing = 'ij')
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
# flatten spatial dimensions for attention with nd rotary
pos = repeat(pos, '... p -> b (...) p', b = batch)
x, packed_shape = pack([x], 'b * d')
x = self.dropout(x)
embed = self.transformer(x, pos)
# return the embed with reconstituted patch shape
if return_embed:
embed, = unpack(embed, packed_shape, 'b * d')
return embed
# pooling to logits
pooled = reduce(embed, 'b n d -> b d', 'mean')
pooled = self.to_latent(pooled)
return self.mlp_head(pooled)
if __name__ == '__main__':
model = ViTND(
ndim = 5,
input_shape = (4, 8, 16, 32, 64),
patch_size = (2, 2, 4, 4, 8),
num_classes = 1000,
dim = 512,
depth = 6,
heads = 8,
mlp_dim = 2048,
channels = 3,
dropout = 0.1,
emb_dropout = 0.1
)
data = torch.randn(3, 3, 4, 8, 16, 32, 64)
logits = model(data)
embed = model(data, return_embed = True) # (2, 2, 4, 4, 8, 8, 512)

View File

@@ -141,7 +141,7 @@ class ViT(nn.Module):
self.global_average_pool = pool == 'mean'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)