Compare commits

...

18 Commits
1.14.2 ... main

Author SHA1 Message Date
lucidrains
fb5014f0ee get a version of n-dimensional vit with golden gate polar coordinate embeddings into the repo for future use 2025-12-25 09:11:13 -08:00
Phil Wang
0b7518ef45 educate 2025-12-21 07:06:20 -08:00
lucidrains
077d8c188f fix distill 2025-12-10 15:52:10 -08:00
lucidrains
5888f05300 1.16.4 2025-12-07 04:32:52 -08:00
Amit Moryossef
d518e89573 cache position grids in NaViT forward pass (#354)
Use lru_cache to cache unique (ph, pw, device) position grids, avoiding
redundant computation when multiple images share the same patch
dimensions. Cache persists across forward passes.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-07 04:32:30 -08:00
lucidrains
dd6462d19b release small navit perf 2025-12-06 04:57:12 -08:00
Amit Moryossef
a1ee1daa1a optimize NaViT with SDPA and vectorized forward pass (#353)
- Replace manual attention with F.scaled_dot_product_attention
- Use repeat_interleave instead of meshgrid for position computation
- Build image_ids efficiently with repeat_interleave instead of F.pad
- Remove unused Rearrange import

~56% speedup (91ms -> 58ms on 512 variable-sized images)
Numerically equivalent (max diff ~5e-4, within flash attention tolerance)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-06 04:56:40 -08:00
lucidrains
3cff5e547a address https://github.com/lucidrains/vit-pytorch/issues/352 2025-12-02 05:21:52 -08:00
lucidrains
fdaf7f92b9 fix positional embed for mean pool case and cleanup 2025-11-27 17:01:47 -08:00
lucidrains
0ebd4edab9 address https://github.com/lucidrains/vit-pytorch/issues/351 2025-11-27 06:07:43 -08:00
lucidrains
aa49c2783a VAAT should have two ears 2025-11-22 08:32:23 -08:00
lucidrains
6aa0374313 register tokens for the AST in VAAT 2025-11-22 08:12:01 -08:00
lucidrains
b35a97de05 improvise a variant of VAT with audio cortex before fully generalizing it 2025-11-22 07:51:19 -08:00
lucidrains
1374b93145 the paper claims finetuning everything was better, but just allow for freezing the visual cortex, what PI proposes 2025-11-09 10:59:55 -08:00
lucidrains
4386742cd1 an option to return zero for decorr aux loss if insufficient samples 2025-11-09 10:08:06 -08:00
lucidrains
5cf8384c56 add a vit with decorrelation auxiliary losses for mha and feedforwards, right after prenorm - this is in line with a paper from the netherlands, but without extra parameters or their manual sgd update scheme 2025-10-28 12:17:32 -07:00
lucidrains
f7d59cecb5 some register tokens cannot hurt for VAT 2025-10-24 14:00:38 -07:00
lucidrains
a583cb5988 last tweak to vat 2025-10-23 12:21:09 -07:00
15 changed files with 1628 additions and 80 deletions

View File

@@ -49,7 +49,7 @@
## Vision Transformer - Pytorch
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the [attention](https://www.youtube.com/watch?v=eMlx5fFNoYc) revolution.
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>.
@@ -2201,4 +2201,28 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{carrigg2025decorrelationspeedsvisiontransformers,
title = {Decorrelation Speeds Up Vision Transformers},
author = {Kieran Carrigg and Rob van Gastel and Melda Yeghaian and Sander Dalm and Faysal Boughorbel and Marcel van Gerven},
year = {2025},
eprint = {2510.14657},
archivePrefix = {arXiv},
primaryClass = {cs.CV},
url = {https://arxiv.org/abs/2510.14657},
}
```
```bibtex
@misc{gopalakrishnan2025decouplingwhatwherepolar,
title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
year = {2025},
eprint = {2509.10534},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2509.10534},
}
```
*I visualise a time when we will be to robots what dogs are to humans, and Im rooting for the machines.* — Claude Shannon

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.14.2"
version = "1.17.1"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }

107
train_vit_decorr.py Normal file
View File

@@ -0,0 +1,107 @@
# /// script
# dependencies = [
# "accelerate",
# "vit-pytorch",
# "wandb"
# ]
# ///
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.datasets import CIFAR100
# constants
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
EPOCHS = 10
DECORR_LOSS_WEIGHT = 1e-1
TRACK_EXPERIMENT_ONLINE = False
# helpers
def exists(v):
return v is not None
# data
transform = T.Compose([
T.ToTensor(),
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = CIFAR100(
root = 'data',
download = True,
train = True,
transform = transform
)
dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)
# model
from vit_pytorch.vit_with_decorr import ViT
vit = ViT(
dim = 128,
num_classes = 100,
image_size = 32,
patch_size = 4,
depth = 6,
heads = 8,
dim_head = 64,
mlp_dim = 128 * 4,
decorr_sample_frac = 1. # use all tokens
)
# optim
from torch.optim import Adam
optim = Adam(vit.parameters(), lr = LEARNING_RATE)
# prepare
from accelerate import Accelerator
accelerator = Accelerator()
vit, optim, dataloader = accelerator.prepare(vit, optim, dataloader)
# experiment
import wandb
wandb.init(
project = 'vit-decorr',
mode = 'disabled' if not TRACK_EXPERIMENT_ONLINE else 'online'
)
wandb.run.name = 'baseline'
# loop
for _ in range(EPOCHS):
for images, labels in dataloader:
logits, decorr_aux_loss = vit(images)
loss = F.cross_entropy(logits, labels)
total_loss = (
loss +
decorr_aux_loss * DECORR_LOSS_WEIGHT
)
wandb.log(dict(loss = loss, decorr_loss = decorr_aux_loss))
accelerator.print(f'loss: {loss.item():.3f} | decorr aux loss: {decorr_aux_loss.item():.3f}')
accelerator.backward(total_loss)
optim.step()
optim.zero_grad()

View File

@@ -25,12 +25,12 @@ class DistillMixin:
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
cls_tokens = repeat(self.cls_token, 'n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
x += self.pos_embedding[:(n + 1)]
if distilling:
distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
distill_tokens = repeat(distill_token, 'n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
x = self._attend(x)
@@ -125,7 +125,7 @@ class DistillWrapper(Module):
self.alpha = alpha
self.hard = hard
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
self.distillation_token = nn.Parameter(torch.randn(1, dim))
self.distill_mlp = nn.Sequential(
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from functools import partial
from functools import partial, lru_cache
from typing import List
import torch
@@ -9,7 +9,6 @@ from torch import nn, Tensor
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
@@ -28,6 +27,12 @@ def pair(t):
def divisible_by(numer, denom):
return (numer % denom) == 0
@lru_cache(maxsize=128)
def posemb_grid(ph, pw, device):
h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
w_idx = torch.arange(pw, device=device).repeat(ph)
return torch.stack([h_idx, w_idx], dim=-1)
# auto grouping images
def group_images_by_max_seq_len(
@@ -117,8 +122,7 @@ class Attention(nn.Module):
self.q_norm = RMSNorm(heads, dim_head)
self.k_norm = RMSNorm(heads, dim_head)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.dropout_p = dropout
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -145,19 +149,22 @@ class Attention(nn.Module):
q = self.q_norm(q)
k = self.k_norm(k)
dots = torch.matmul(q, k.transpose(-1, -2))
# combine masks if both exist
if exists(mask) or exists(attn_mask):
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
if exists(mask) and exists(attn_mask):
attn_mask = mask & attn_mask
elif exists(mask):
attn_mask = mask
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = attn_mask,
dropout_p = self.dropout_p if self.training else 0.,
scale = 1. # RMSNorm already includes sqrt(dim) scaling
)
if exists(attn_mask):
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -281,42 +288,41 @@ class NaViT(nn.Module):
for images in batched_images:
num_images.append(len(images))
sequences = []
positions = []
image_ids = torch.empty((0,), device = device, dtype = torch.long)
for image_id, image in enumerate(images):
assert image.ndim ==3 and image.shape[0] == c
# compute patch dimensions for all images
patch_dims = []
for image in images:
assert image.ndim == 3 and image.shape[0] == c
image_dims = image.shape[-2:]
assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
patch_dims.append((image_dims[0] // p, image_dims[1] // p))
ph, pw = map(lambda dim: dim // p, image_dims)
# extract patches for all images
sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images]
pos = torch.stack(torch.meshgrid((
arange(ph),
arange(pw)
), indexing = 'ij'), dim = -1)
# compute positions - uses lru_cache to avoid redundant computation across forward passes
positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims]
pos = rearrange(pos, 'h w c -> (h w) c')
seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
seq_len = seq.shape[-2]
if has_token_dropout:
# handle token dropout
if has_token_dropout:
for i, (seq, pos) in enumerate(zip(sequences, positions)):
image_dims = images[i].shape[-2:]
token_dropout = self.calc_token_dropout(*image_dims)
seq_len = seq.shape[0]
num_keep = max(1, int(seq_len * (1 - token_dropout)))
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
keep_indices = torch.randn((seq_len,), device=device).topk(num_keep, dim=-1).indices
sequences[i] = seq[keep_indices]
positions[i] = pos[keep_indices]
seq = seq[keep_indices]
pos = pos[keep_indices]
image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
sequences.append(seq)
positions.append(pos)
# build image_ids efficiently using repeat_interleave
patch_counts = [seq.shape[0] for seq in sequences]
image_ids = torch.repeat_interleave(
arange(len(images)),
torch.tensor(patch_counts, device=device)
)
batched_image_ids.append(image_ids)
batched_sequences.append(torch.cat(sequences, dim = 0))
batched_positions.append(torch.cat(positions, dim = 0))
batched_sequences.append(torch.cat(sequences, dim=0))
batched_positions.append(torch.cat(positions, dim=0))
# derive key padding mask

View File

@@ -176,7 +176,7 @@ class NaViT(Module):
self.channels = channels
self.patch_size = patch_size
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c pf p1 p2)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
self.to_patch_embedding = nn.Sequential(
nn.LayerNorm(patch_dim),

View File

@@ -146,7 +146,7 @@ class SimpleViT(Module):
patch_dim = channels * patch_height * patch_width * frame_patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),

View File

@@ -103,7 +103,7 @@ class SimpleViT(nn.Module):
patch_dim = channels * patch_height * patch_width * frame_patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),

777
vit_pytorch/vaat.py Normal file
View File

@@ -0,0 +1,777 @@
# vision-audio-action transformer - vaat
from __future__ import annotations
from contextlib import nullcontext
import torch
import torch.nn.functional as F
from torch import nn, cat, stack, arange, tensor
from torch.nn import Module, ModuleList
from torchaudio.transforms import Spectrogram
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# 2d sinusoidal positional embedding
# simple vit paper shows it is good enough compared to learned
def posemb_sincos_2d(
patches,
temperature = 10000,
dtype = torch.float32
):
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
y, x = torch.meshgrid(arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
omega = arange(dim // 4, device = device) / (dim // 4 - 1)
omega = temperature ** -omega
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
pe = pe.type(dtype)
return rearrange(pe, '(h w) d -> h w d', h = h, w = w)
# classes
class FiLM(Module):
def __init__(
self,
dim,
):
super().__init__()
proj = nn.Linear(dim, dim * 2)
self.to_gamma_beta = nn.Sequential(
proj,
Rearrange('b (two d) -> two b 1 d', two = 2)
)
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)
def forward(self, tokens, cond):
gamma, beta = self.to_gamma_beta(cond)
return tokens * gamma + beta
class FeedForward(Module):
def __init__(
self,
dim,
hidden_dim,
dropout = 0.
):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(
self,
dim,
heads = 8,
dim_head = 64,
dropout = 0.,
dim_context = None,
cross_attend = False
):
super().__init__()
dim_context = default(dim_context, dim)
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.cross_attend = cross_attend
self.context_norm = nn.LayerNorm(dim_context) if cross_attend else None
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, context = None):
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross attending, or vice versa'
x = self.norm(x)
# handle norming of context for cross attention
kv_input = x
if self.cross_attend:
context = self.context_norm(context)
kv_input = context
# project for queries, keys, values
qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(
self,
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout = 0.
):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(
self,
x,
return_hiddens = False
):
hiddens = []
for attn, ff in self.layers:
hiddens.append(x)
x = attn(x) + x
x = ff(x) + x
x = self.norm(x)
if not return_hiddens:
return x
return x, hiddens
class AST(Module):
# audio spectrogram transformer https://arxiv.org/abs/2104.01778
def __init__(
self,
dim,
depth,
mlp_dim,
num_classes = None,
patch_size = 16,
dim_head = 64,
heads = 8,
dropout = 0.,
accept_spec = False,
accept_spec_time_first = True,
spec_n_fft = 128,
spec_power = 2,
spec_win_length = 24,
spec_hop_length = None,
spec_pad = 0,
spec_center = True,
spec_pad_mode = 'reflect',
num_register_tokens = 4
):
super().__init__()
self.dim = dim
self.depth = depth
patch_height, patch_width = pair(patch_size)
patch_input_dim = patch_height * patch_width
self.patch_size = (patch_height, patch_width)
self.to_patch_tokens = nn.Sequential(
Rearrange('b (h p1) (w p2) -> b h w (p1 p2)', p1 = self.patch_size[0], p2 = self.patch_size[1]),
nn.LayerNorm(patch_input_dim),
nn.Linear(patch_input_dim, dim),
nn.LayerNorm(dim)
)
self.accept_spec = accept_spec
self.accept_spec_time_first = accept_spec_time_first
self.spec = Spectrogram(
n_fft = spec_n_fft,
power = spec_power,
win_length = spec_win_length,
hop_length = spec_hop_length,
pad = spec_pad,
center = spec_center,
pad_mode = spec_pad_mode
)
self.transformer = Transformer(
dim = dim,
depth = depth,
dim_head = dim_head,
heads = heads,
mlp_dim = mlp_dim,
dropout = dropout,
)
self.final_norm = nn.LayerNorm(dim)
self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
def forward(
self,
raw_audio_or_spec, # (b t) | (b f t)
return_hiddens = False
):
batch, device = raw_audio_or_spec.shape[0], raw_audio_or_spec.device
assert (self.accept_spec and raw_audio_or_spec.ndim == 3) or (not self.accept_spec and raw_audio_or_spec.ndim == 2)
if self.accept_spec:
spec = rearrange(raw_audio_or_spec, 'b t f -> b f t')
else:
spec = self.spec(raw_audio_or_spec)
# automatically crop if audio does not yield a 2d spectrogram that is divisible by patch sizes
height, width = spec.shape[-2:]
patch_height, patch_width = self.patch_size
rounded_height = height // patch_height * patch_height
rounded_width = width // patch_width * patch_width
spec = spec[..., :rounded_height, :rounded_width]
# to patches
tokens = self.to_patch_tokens(spec)
# get number of patches along height and width
_, num_patch_height, num_patch_width, _ = tokens.shape
# 2d sinusoidal positional embedding
tokens = tokens + posemb_sincos_2d(tokens)
tokens = rearrange(tokens, 'b ... c -> b (...) c')
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
tokens, packed_shape = pack((register_tokens, tokens), 'b * d')
# attention
attended, hiddens = self.transformer(tokens, return_hiddens = True)
# final global average and norm (most recent papers show this is superior to CLS token)
normed = self.final_norm(attended)
if return_hiddens:
return normed, stack(hiddens)
register_tokens, normed = unpack(normed, packed_shape, 'b * d')
pooled = reduce(normed, 'b n d -> b d', 'mean')
maybe_logits = self.mlp_head(pooled)
return maybe_logits
class ViT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
pool = 'cls',
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
num_register_tokens = 0
):
super().__init__()
self.dim = dim
self.depth = depth
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
self.cls_token = nn.Parameter(torch.randn(dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
def forward(self, img, return_hiddens = False):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
x += self.pos_embedding[:n]
cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = b)
x, packed_shape = pack((register_tokens, cls_tokens, x), 'b * d')
x = self.dropout(x)
x, hiddens = self.transformer(x, return_hiddens = True)
# return the representation trajectory
if return_hiddens:
return x, stack(hiddens)
register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
x = self.to_latent(x)
return self.mlp_head(x)
# proposed VAT
# https://openreview.net/forum?id=TalHOvvLZu
# simple way to get SOTA on Libero dataset (beating fine-tuned pi-zero)
class VAAT(Module):
def __init__(
self,
vit: ViT | dict,
ast: AST | dict,
*,
dim,
depth,
heads,
dim_head,
dim_action,
mlp_dim,
num_image_views = None,
num_audio_views = None,
num_tasks = None,
dim_extra_token = None,
num_register_tokens = 4,
action_chunk_len = 7,
time_seq_len = 1,
dropout = 0.,
add_self_attn = True, # in the paper, they didn't have any ways for the action token to exchange information with the extra token, so we'll just add it as an option
self_attn_heads = 4,
self_attn_dim_head = 32,
ast_layer_indices: tuple[int, ...] | None = None,
vit_layer_indices: tuple[int, ...] | None = None
):
super().__init__()
# vit
if isinstance(vit, dict):
vit = ViT(**vit)
self.vit = vit
vit_dim = vit.dim
assert vit.depth == depth or exists(vit_layer_indices), f'if the VAAT depth is not equal to the ViT depth, you must pass in the indices from the ViT to be layered to the VAAT in order from bottom to top'
vit_layer_indices = default(vit_layer_indices, tuple(range(depth)))
assert len(vit_layer_indices) == depth, f'number of vit layer indices {len(vit_layer_indices)} does not much the VAT depth {depth}'
self.register_buffer('vit_layer_indices', tensor(vit_layer_indices), persistent = False)
# ast
if isinstance(ast, dict):
ast = AST(**ast)
self.ast = ast
ast_dim = ast.dim
self.ast_accept_spec = ast.accept_spec
assert ast.depth == depth or exists(ast_layer_indices), f'if the VAAT depth is not equal to the AST depth, you must pass in the indices from the AST to be layered to the VAAT in order from bottom to top'
ast_layer_indices = default(ast_layer_indices, tuple(range(depth)))
assert len(ast_layer_indices) == depth, f'number of ast layer indices {len(ast_layer_indices)} does not much the VAAT depth {depth}'
self.register_buffer('ast_layer_indices', tensor(vit_layer_indices), persistent = False)
# handle maybe multiple frames
is_video = time_seq_len > 1
self.is_video = is_video
self.time_seq_len = time_seq_len
self.time_pos_emb = nn.Parameter(torch.randn(time_seq_len, vit_dim) * 1e-2) if is_video else None
# maybe view embeddings
self.image_view_emb = nn.Parameter(torch.randn(num_image_views, vit_dim) * 1e-2) if exists(num_image_views) and num_image_views > 1 else None
self.audio_view_emb = nn.Parameter(torch.randn(num_audio_views, ast_dim) * 1e-2) if exists(num_audio_views) and num_audio_views > 1 else None
# handle maybe task conditioning
self.has_tasks = exists(num_tasks)
if self.has_tasks:
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
# register tokens from Darcet et al.
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
# to action tokens
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
self.layers = ModuleList([])
for _ in range(depth):
maybe_film = FiLM(dim = dim) if self.has_tasks else None
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
self.layers.append(ModuleList([
maybe_film,
maybe_self_attn,
Attention(dim = dim, dim_context = vit_dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
Attention(dim = dim, dim_context = ast_dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
]))
self.final_norm = nn.LayerNorm(dim)
self.to_pred_action = nn.Linear(dim, dim_action, bias = False)
# handle the extra token
self.accept_extra_token = exists(dim_extra_token)
if exists(dim_extra_token):
self.to_extra_token = nn.Linear(dim_extra_token, dim)
def forward(
self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
audio_or_spec, # (b v? t) | (b v?f t) - batch, audio len | batch, spec freq, time
*,
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False,
freeze_ast = False
):
batch = video_or_image.shape[0]
return_loss = exists(actions)
# handle some various input dimensions
if video_or_image.ndim == 4:
video_or_image = rearrange(video_or_image, 'b 1 c h w')
assert (
(video_or_image.ndim == 5 and not self.is_video) or
(video_or_image.ndim == 6 and self.is_video)
)
if video_or_image.ndim == 5:
video_or_image = rearrange(video_or_image, 'b v c h w -> b v c 1 h w')
assert video_or_image.shape[3] == self.time_seq_len
# audio shapes - adding view if impliciy to be 1
if audio_or_spec.ndim == 2 and not self.ast_accept_spec:
audio_or_spec = rearrange(audio_or_spec, 'b t -> b 1 t')
elif audio_or_spec.ndim == 3 and self.ast_accept_spec:
audio_or_spec = rearrange(audio_or_spec, 'b f t -> b 1 f t')
# to images
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
images, image_packed_shape = pack([images], '* c h w')
# to audio
if self.ast_accept_spec:
audio_or_spec, audio_packed_shape = pack([audio_or_spec], '* f t')
else:
audio_or_spec, audio_packed_shape = pack([audio_or_spec], '* t')
# get representation trajectory from vit
vit_forward_context = torch.no_grad if freeze_vit else nullcontext
with vit_forward_context():
embed, hiddens = self.vit(images, return_hiddens = True)
hiddens = cat((hiddens, embed[None, ...]))
# extract the hiddens needed for the action cross attention
hiddens = hiddens[self.vit_layer_indices]
# unpack temporarily for embedding
hiddens, = unpack(hiddens, image_packed_shape, 'l * n d') # l for layers
# maybe add time embeddings
if self.is_video:
time_pos_emb = rearrange(self.time_pos_emb, 't d -> t 1 d')
hiddens = hiddens + time_pos_emb
# maybe view embeddings
if exists(self.image_view_emb):
assert self.image_view_emb.shape[0] == hiddens.shape[2]
image_view_emb = rearrange(self.image_view_emb, 'v d -> v 1 1 d')
hiddens = hiddens + image_view_emb
# get representation trajectory from ast
ast_forward_context = torch.no_grad if freeze_ast else nullcontext
with ast_forward_context():
audio_embed, audio_hiddens = self.ast(audio_or_spec, return_hiddens = True)
audio_hiddens = cat((audio_hiddens, audio_embed[None, ...]))
# extract the hiddens needed for the action cross attention
audio_hiddens = audio_hiddens[self.ast_layer_indices]
# unpack audio temporarily for embedding
audio_hiddens, = unpack(audio_hiddens, audio_packed_shape, 'l * n d') # l for layers
# maybe audio view embeddings
if exists(self.audio_view_emb):
assert self.audio_view_emb.shape[0] == audio_hiddens.shape[2]
audio_view_emb = rearrange(self.audio_view_emb, 'v d -> v 1 1 d')
audio_hiddens = audio_hiddens + audio_view_emb
# maybe tasks
if exists(tasks):
assert self.has_tasks, f'`num_tasks` must be set on `VAT` for task conditioning'
task_emb = self.task_emb[tasks]
# cross from actions to representation trajectory
image_context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
audio_context = rearrange(audio_hiddens, 'l b v n d -> l b (v n) d')
# get main action tokens and maybe append extra
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
has_extra = exists(extra)
if has_extra:
assert self.accept_extra_token
extra_token = self.to_extra_token(extra)
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
# cross attention
hiddens = [action_tokens]
for (maybe_film, maybe_self_attn, image_cross_attn, audio_cross_attn, ff), image_layer_context, audio_layer_context in zip(self.layers, image_context, audio_context):
if exists(tasks):
action_tokens = maybe_film(action_tokens, task_emb)
action_tokens = image_cross_attn(action_tokens, image_layer_context) + action_tokens
action_tokens = audio_cross_attn(action_tokens, audio_layer_context) + action_tokens
if exists(maybe_self_attn):
action_tokens = maybe_self_attn(action_tokens) + action_tokens
action_tokens = ff(action_tokens) + action_tokens
hiddens.append(action_tokens)
# unpack registers
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
# maybe unpack extra
if has_extra:
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
# norm and prediction
action_tokens = self.final_norm(action_tokens)
pred_action = self.to_pred_action(action_tokens)
if not return_loss:
if not return_hiddens:
return pred_action
return pred_action, stack(hiddens)
assert pred_action.shape[1] == actions.shape[1]
# they found l1 loss suffices
return F.l1_loss(pred_action, actions)
# quick test
if __name__ == '__main__':
vit = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 384,
heads = 8,
depth = 4,
mlp_dim = 384 * 4
)
ast = AST(
dim = 384,
depth = 4,
heads = 8,
num_classes = 1000,
patch_size = 16,
mlp_dim = 384 * 4
)
vaat = VAAT(
vit,
ast,
dim = 512,
depth = 9,
heads = 8,
dim_head = 64,
mlp_dim = 2048,
dim_action = 20,
action_chunk_len = 7,
time_seq_len = 4,
num_image_views = 2,
num_audio_views = 2,
num_tasks = 4,
add_self_attn = True,
dim_extra_token = 33, # extra token with some variable dimension
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
0, 0, 1, 1, 2, 2, 3, 3, 4
),
ast_layer_indices = (
1, 1, 1, 2, 2, 2, 3, 3, 3
)
)
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
audio = torch.randn(2, 2, 14_100 * 5)
tasks = torch.randint(0, 4, (2,))
extra = torch.randn(2, 33) # extra internal state
actions = torch.randn(2, 7, 20) # actions for learning
loss = vaat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
# after much training
pred_actions, hiddens = vaat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
assert pred_actions.shape == (2, 7, 20)

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
from contextlib import nullcontext
import torch
import torch.nn.functional as F
@@ -66,12 +67,14 @@ class Attention(Module):
def __init__(
self,
dim,
dim_context = None,
heads = 8,
dim_head = 64,
dropout = 0.,
cross_attend = False
):
super().__init__()
dim_context = default(dim_context, dim)
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
@@ -81,13 +84,13 @@ class Attention(Module):
self.norm = nn.LayerNorm(dim)
self.cross_attend = cross_attend
self.context_norm = nn.LayerNorm(dim) if cross_attend else None
self.context_norm = nn.LayerNorm(dim_context) if cross_attend else None
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
@@ -178,7 +181,8 @@ class ViT(Module):
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.
emb_dropout = 0.,
num_register_tokens = 0
):
super().__init__()
self.dim = dim
@@ -200,8 +204,8 @@ class ViT(Module):
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
self.cls_token = nn.Parameter(torch.randn(dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
@@ -211,13 +215,19 @@ class ViT(Module):
self.mlp_head = nn.Linear(dim, num_classes)
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
def forward(self, img, return_hiddens = False):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x += self.pos_embedding[:n]
cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = b)
x, packed_shape = pack((register_tokens, cls_tokens, x), 'b * d')
x = self.dropout(x)
x, hiddens = self.transformer(x, return_hiddens = True)
@@ -227,7 +237,9 @@ class ViT(Module):
if return_hiddens:
return x, stack(hiddens)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
x = self.to_latent(x)
return self.mlp_head(x)
@@ -251,6 +263,7 @@ class VAT(Module):
num_views = None,
num_tasks = None,
dim_extra_token = None,
num_register_tokens = 4,
action_chunk_len = 7,
time_seq_len = 1,
dropout = 0.,
@@ -295,6 +308,10 @@ class VAT(Module):
if self.has_tasks:
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
# register tokens from Darcet et al.
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
# to action tokens
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
@@ -308,7 +325,7 @@ class VAT(Module):
self.layers.append(ModuleList([
maybe_film,
maybe_self_attn,
Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
Attention(dim = dim, dim_context = vit_dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
]))
@@ -329,6 +346,8 @@ class VAT(Module):
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False
):
batch = video_or_image.shape[0]
return_loss = exists(actions)
@@ -356,7 +375,10 @@ class VAT(Module):
# get representation trajectory from vit
embed, hiddens = self.vit(images, return_hiddens = True)
vit_forward_context = torch.no_grad if freeze_vit else nullcontext
with vit_forward_context():
embed, hiddens = self.vit(images, return_hiddens = True)
hiddens = cat((hiddens, embed[None, ...]))
@@ -406,8 +428,16 @@ class VAT(Module):
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
# cross attention
hiddens = [action_tokens]
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
if exists(tasks):
@@ -420,6 +450,12 @@ class VAT(Module):
action_tokens = ff(action_tokens) + action_tokens
hiddens.append(action_tokens)
# unpack registers
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
# maybe unpack extra
if has_extra:
@@ -432,7 +468,10 @@ class VAT(Module):
pred_action = self.to_pred_action(action_tokens)
if not return_loss:
return pred_action
if not return_hiddens:
return pred_action
return pred_action, stack(hiddens)
assert pred_action.shape[1] == actions.shape[1]
@@ -448,10 +487,10 @@ if __name__ == '__main__':
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 512,
dim = 256,
heads = 8,
depth = 4,
mlp_dim = 2048
mlp_dim = 1024
)
vat = VAT(
@@ -479,11 +518,11 @@ if __name__ == '__main__':
actions = torch.randn(2, 7, 20) # actions for learning
loss = vat(images, actions = actions, tasks = tasks, extra = extra)
loss = vat(images, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
# after much training
pred_actions = vat(images, tasks = tasks, extra = extra)
pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True)
assert pred_actions.shape == (2, 7, 20)

View File

@@ -1,5 +1,6 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
@@ -11,7 +12,7 @@ def pair(t):
# classes
class FeedForward(nn.Module):
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
@@ -26,7 +27,7 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
@@ -62,13 +63,14 @@ class Attention(nn.Module):
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
@@ -80,7 +82,7 @@ class Transformer(nn.Module):
return self.norm(x)
class ViT(nn.Module):
class ViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
@@ -90,7 +92,9 @@ class ViT(nn.Module):
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_cls_tokens = 1 if pool == 'cls' else 0
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
@@ -99,8 +103,9 @@ class ViT(nn.Module):
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
@@ -111,12 +116,15 @@ class ViT(nn.Module):
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, img):
batch = img.shape[0]
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
x = torch.cat((cls_tokens, x), dim = 1)
seq = x.shape[1]
x = x + self.pos_embedding[:seq]
x = self.dropout(x)
x = self.transformer(x)

View File

@@ -89,7 +89,7 @@ class ViT(nn.Module):
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),

353
vit_pytorch/vit_nd_pope.py Normal file
View File

@@ -0,0 +1,353 @@
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import pi, nn, arange, cat, stack, Tensor
from torch.nn import Module, ModuleList
from torch.amp import autocast
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
def join(arr, delimiter = ' '):
return delimiter.join(arr)
def ensure_tuple(t, length):
if isinstance(t, (tuple, list)):
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
return tuple(t)
return (t,) * length
# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/
# but using polar version instead
# Gopalakrishnan et al. https://arxiv.org/abs/2509.10534
def _phi(m: int) -> float:
x = 2.0
for _ in range(10):
x = (1 + x) ** (1.0 / (m + 1.0))
return x
def make_directions(n: int, d: int) -> Tensor:
g = _phi(d)
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
z = torch.fmod(i * alpha, 1.0)
directions = torch.erfinv(2.0 * z - 1.0)
directions = l2norm(directions)
return directions.float()
class GoldenGatePoPENd(Module):
def __init__(
self,
dim_pos: int,
heads: int,
dim_head: int,
min_freq: float = 1.0,
max_freq: float = 10000.0,
p_zero_freqs: float = 0.0, # proportion of frequencies set to 0
init_learned_bias_uniform = False
):
super().__init__()
n_freqs = dim_head
n_zero_freqs = round(p_zero_freqs * n_freqs)
omega = cat((
torch.zeros(n_zero_freqs),
min_freq * (max_freq / min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
))
directions = rearrange(
make_directions(heads * n_freqs, dim_pos),
'(h f) p -> h f p',
h = heads
)
omega_expanded = rearrange(omega, 'f -> f 1')
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
self.learned_bias = nn.Parameter(torch.zeros(heads, dim_head))
if init_learned_bias_uniform:
self.learned_bias.uniform_(-2. * pi, 0.)
@autocast('cuda', enabled = False)
def forward(self, pos):
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
# compute theta for each (batch, head, seq, freq)
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
bias = self.learned_bias.clamp(-2. * pi, 0.)
bias = rearrange(bias, 'h d -> h 1 d')
return theta, bias
@autocast('cuda', enabled = False)
def apply_polar_pos_emb(t, freqs):
orig_dtype = t.dtype
t = t.float()
t = F.softplus(t)
out = cat((t * freqs.cos(), t * freqs.sin()), dim = -1)
return out.type(orig_dtype)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, polar_pos_emb = None):
x = self.norm(x)
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
if exists(polar_pos_emb):
freqs, bias = polar_pos_emb
q = apply_polar_pos_emb(q, freqs)
k = apply_polar_pos_emb(k, freqs + bias)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., polar_emb = None):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.polar_emb = polar_emb
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, pos = None):
# pope embedding
polar_pos_emb = None
if exists(pos) and exists(self.polar_emb):
polar_pos_emb = self.polar_emb(pos)
# transformer layers
for attn, ff in self.layers:
x = attn(x, polar_pos_emb) + x
x = ff(x) + x
return self.norm(x)
class ViTND(Module):
def __init__(
self,
*,
ndim: int,
input_shape: int | tuple[int, ...],
patch_size: int | tuple[int, ...],
num_classes: int,
dim: int,
depth: int,
heads: int,
mlp_dim: int,
channels: int = 3,
dim_head: int = 64,
dropout: float = 0.,
emb_dropout: float = 0.,
pope_min_freq: float = 1.0,
pope_max_freq: float = 10000.0,
pope_p_zero_freqs: float = 0.0,
pope_init_learned_bias_uniform = False
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
self.ndim = ndim
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.dropout = nn.Dropout(emb_dropout)
# golden gate pope
self.polar_emb = GoldenGatePoPENd(
dim_pos = ndim,
heads = heads,
dim_head = dim_head,
min_freq = pope_min_freq,
max_freq = pope_max_freq,
p_zero_freqs = pope_p_zero_freqs,
init_learned_bias_uniform = pope_init_learned_bias_uniform
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, polar_emb = self.polar_emb)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def muon_parameters(self):
params = []
for m in self.modules():
if isinstance(m, Attention):
params.extend([
m.to_v.weight,
m.to_out[0].weight
])
elif isinstance(m, FeedForward):
params.extend([
m.net[1].weight,
m.net[-2].weight
])
return params
def forward(
self,
x,
return_embed = False
):
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
batch, *spatial_dims, _, device = *x.shape, x.device
# Generate position coordinates
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
grid = torch.meshgrid(*grids, indexing = 'ij')
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
# flatten spatial dimensions for attention with nd rotary
pos = repeat(pos, '... p -> b (...) p', b = batch)
x, packed_shape = pack([x], 'b * d')
x = self.dropout(x)
embed = self.transformer(x, pos)
# return the embed with reconstituted patch shape
if return_embed:
embed, = unpack(embed, packed_shape, 'b * d')
return embed
# pooling to logits
pooled = reduce(embed, 'b n d -> b d', 'mean')
pooled = self.to_latent(pooled)
return self.mlp_head(pooled)
if __name__ == '__main__':
model = ViTND(
ndim = 5,
input_shape = (4, 8, 16, 32, 64),
patch_size = (2, 2, 4, 4, 8),
num_classes = 1000,
dim = 512,
depth = 6,
heads = 8,
mlp_dim = 2048,
channels = 3,
dropout = 0.1,
emb_dropout = 0.1
)
data = torch.randn(3, 3, 4, 8, 16, 32, 64)
logits = model(data)
embed = model(data, return_embed = True)
assert embed.shape == (3, 2, 4, 4, 8, 8, 512)

View File

@@ -0,0 +1,234 @@
# https://arxiv.org/abs/2510.14657
# but instead of their decorr module updated with SGD, remove all projections and just return a decorrelation auxiliary loss
import torch
from torch import nn, stack, tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# decorr loss
class DecorrelationLoss(Module):
def __init__(
self,
sample_frac = 1.,
soft_validate_num_sampled = False
):
super().__init__()
assert 0. <= sample_frac <= 1.
self.need_sample = sample_frac < 1.
self.sample_frac = sample_frac
self.soft_validate_num_sampled = soft_validate_num_sampled
self.register_buffer('zero', tensor(0.), persistent = False)
def forward(
self,
tokens
):
batch, seq_len, dim, device = *tokens.shape[-3:], tokens.device
if self.need_sample:
num_sampled = int(seq_len * self.sample_frac)
assert self.soft_validate_num_sampled or num_sampled >= 2.
if num_sampled <= 1:
return self.zero
tokens, packed_shape = pack([tokens], '* n d e')
indices = torch.randn(tokens.shape[:2]).argsort(dim = -1)[..., :num_sampled, :]
batch_arange = torch.arange(tokens.shape[0], device = tokens.device)
batch_arange = rearrange(batch_arange, 'b -> b 1')
tokens = tokens[batch_arange, indices]
tokens, = unpack(tokens, packed_shape, '* n d e')
dist = einsum(tokens, tokens, '... n d, ... n e -> ... d e') / tokens.shape[-2]
eye = torch.eye(dim, device = device)
loss = dist.pow(2) * (1. - eye) / ((dim - 1) * dim)
loss = reduce(loss, '... b d e -> b', 'sum')
return loss.mean()
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
normed = self.norm(x)
return self.net(x), normed
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.norm = nn.LayerNorm(dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
normed = self.norm(x)
qkv = self.to_qkv(normed).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out), normed
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
normed_inputs = []
for attn, ff in self.layers:
attn_out, attn_normed_inp = attn(x)
x = attn_out + x
ff_out, ff_normed_inp = ff(x)
x = ff_out + x
normed_inputs.append(attn_normed_inp)
normed_inputs.append(ff_normed_inp)
return self.norm(x), stack(normed_inputs)
class ViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., decorr_sample_frac = 1.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
# decorrelation loss related
self.has_decorr_loss = decorr_sample_frac > 0.
if self.has_decorr_loss:
self.decorr_loss = DecorrelationLoss(decorr_sample_frac)
self.register_buffer('zero', torch.tensor(0.), persistent = False)
def forward(
self,
img,
return_decorr_aux_loss = None
):
return_decorr_aux_loss = default(return_decorr_aux_loss, self.training) and self.has_decorr_loss
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x, normed_layer_inputs = self.transformer(x)
# maybe return decor loss
decorr_aux_loss = self.zero
if return_decorr_aux_loss:
decorr_aux_loss = self.decorr_loss(normed_layer_inputs)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x), decorr_aux_loss
# quick test
if __name__ == '__main__':
decorr_loss = DecorrelationLoss(0.1)
hiddens = torch.randn(6, 2, 512, 256)
decorr_loss(hiddens)
decorr_loss(hiddens[0])
decorr_loss = DecorrelationLoss(0.0001, soft_validate_num_sampled = True)
out = decorr_loss(hiddens)
assert out.item() == 0

View File

@@ -141,7 +141,7 @@ class ViT(nn.Module):
self.global_average_pool = pool == 'mean'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)