Compare commits

...

8 Commits

6 changed files with 1199 additions and 18 deletions

README.md
View File

@@ -2201,4 +2201,16 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{carrigg2025decorrelationspeedsvisiontransformers,
title = {Decorrelation Speeds Up Vision Transformers},
author = {Kieran Carrigg and Rob van Gastel and Melda Yeghaian and Sander Dalm and Faysal Boughorbel and Marcel van Gerven},
year = {2025},
eprint = {2510.14657},
archivePrefix = {arXiv},
primaryClass = {cs.CV},
url = {https://arxiv.org/abs/2510.14657},
}
```
*I visualise a time when we will be to robots what dogs are to humans, and I'm rooting for the machines.* — Claude Shannon

pyproject.toml
View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.14.1"
version = "1.15.6"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }

107
train_vit_decorr.py Normal file
View File

@@ -0,0 +1,107 @@
# /// script
# dependencies = [
# "accelerate",
# "vit-pytorch",
# "wandb"
# ]
# ///
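# the block above is inline script metadata - with a compatible runner (eg. `uv run train_vit_decorr.py`)
# the listed dependencies are resolved automatically before the script executes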
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.datasets import CIFAR100
# constants
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
EPOCHS = 10
DECORR_LOSS_WEIGHT = 1e-1
TRACK_EXPERIMENT_ONLINE = False
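# when False, wandb is initialized in disabled mode and nothing is logged online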
# helpers
def exists(v):
return v is not None
# data
transform = T.Compose([
T.ToTensor(),
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = CIFAR100(
root = 'data',
download = True,
train = True,
transform = transform
)
dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)
# model
from vit_pytorch.vit_with_decorr import ViT
vit = ViT(
dim = 128,
num_classes = 100,
image_size = 32,
patch_size = 4,
depth = 6,
heads = 8,
dim_head = 64,
mlp_dim = 128 * 4,
decorr_sample_frac = 1. # use all tokens
)
# optim
from torch.optim import Adam
optim = Adam(vit.parameters(), lr = LEARNING_RATE)
# prepare
from accelerate import Accelerator
accelerator = Accelerator()
vit, optim, dataloader = accelerator.prepare(vit, optim, dataloader)
# experiment
import wandb
wandb.init(
project = 'vit-decorr',
mode = 'disabled' if not TRACK_EXPERIMENT_ONLINE else 'online'
)
wandb.run.name = 'baseline'
# loop
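# total objective is cross entropy on the CIFAR-100 logits plus the decorrelation
# auxiliary loss returned by the ViT forward, scaled by DECORR_LOSS_WEIGHT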
for _ in range(EPOCHS):
for images, labels in dataloader:
logits, decorr_aux_loss = vit(images)
loss = F.cross_entropy(logits, labels)
total_loss = (
loss +
decorr_aux_loss * DECORR_LOSS_WEIGHT
)
wandb.log(dict(loss = loss, decorr_loss = decorr_aux_loss))
accelerator.print(f'loss: {loss.item():.3f} | decorr aux loss: {decorr_aux_loss.item():.3f}')
accelerator.backward(total_loss)
optim.step()
optim.zero_grad()

744
vit_pytorch/vaat.py Normal file
View File

@@ -0,0 +1,744 @@
# vision-audio-action transformer - vaat
from __future__ import annotations
from contextlib import nullcontext
import torch
import torch.nn.functional as F
from torch import nn, cat, stack, arange, tensor
from torch.nn import Module, ModuleList
from torchaudio.transforms import Spectrogram
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# 2d sinusoidal positional embedding
# simple vit paper shows it is good enough compared to learned
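# builds dim // 4 frequencies and concatenates sin / cos of the scaled x and y grid
# coordinates into one dim-dimensional embedding per patch position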
def posemb_sincos_2d(
patches,
temperature = 10000,
dtype = torch.float32
):
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
y, x = torch.meshgrid(arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
omega = arange(dim // 4, device = device) / (dim // 4 - 1)
omega = temperature ** -omega
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
pe = pe.type(dtype)
return rearrange(pe, '(h w) d -> h w d', h = h, w = w)
# classes
class FiLM(Module):
def __init__(
self,
dim,
):
super().__init__()
proj = nn.Linear(dim, dim * 2)
self.to_gamma_beta = nn.Sequential(
proj,
Rearrange('b (two d) -> two b 1 d', two = 2)
)
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)
def forward(self, tokens, cond):
gamma, beta = self.to_gamma_beta(cond)
return tokens * (gamma + 1.) + beta # zero-initialized projection, so conditioning starts as the identity
class FeedForward(Module):
def __init__(
self,
dim,
hidden_dim,
dropout = 0.
):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(
self,
dim,
heads = 8,
dim_head = 64,
dropout = 0.,
dim_context = None,
cross_attend = False
):
super().__init__()
dim_context = default(dim_context, dim)
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.cross_attend = cross_attend
self.context_norm = nn.LayerNorm(dim_context) if cross_attend else None
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, context = None):
assert not (self.cross_attend ^ exists(context)), 'context must be passed in when cross attending, and omitted when not'
x = self.norm(x)
# handle norming of context for cross attention
kv_input = x
if self.cross_attend:
context = self.context_norm(context)
kv_input = context
# project for queries, keys, values
qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(
self,
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout = 0.
):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(
self,
x,
return_hiddens = False
):
hiddens = []
for attn, ff in self.layers:
hiddens.append(x)
x = attn(x) + x
x = ff(x) + x
x = self.norm(x)
if not return_hiddens:
return x
return x, hiddens
class AST(Module):
# audio spectrogram transformer https://arxiv.org/abs/2104.01778
def __init__(
self,
dim,
depth,
mlp_dim,
num_classes = None,
patch_size = 16,
dim_head = 64,
heads = 8,
dropout = 0.,
accept_spec = False,
accept_spec_time_first = True,
spec_n_fft = 128,
spec_power = 2,
spec_win_length = 24,
spec_hop_length = None,
spec_pad = 0,
spec_center = True,
spec_pad_mode = 'reflect',
num_register_tokens = 4
):
super().__init__()
self.dim = dim
self.depth = depth
patch_height, patch_width = pair(patch_size)
patch_input_dim = patch_height * patch_width
self.patch_size = (patch_height, patch_width)
self.to_patch_tokens = nn.Sequential(
Rearrange('b (h p1) (w p2) -> b h w (p1 p2)', p1 = self.patch_size[0], p2 = self.patch_size[1]),
nn.LayerNorm(patch_input_dim),
nn.Linear(patch_input_dim, dim),
nn.LayerNorm(dim)
)
self.accept_spec = accept_spec
self.accept_spec_time_first = accept_spec_time_first
self.spec = Spectrogram(
n_fft = spec_n_fft,
power = spec_power,
win_length = spec_win_length,
hop_length = spec_hop_length,
pad = spec_pad,
center = spec_center,
pad_mode = spec_pad_mode
)
self.transformer = Transformer(
dim = dim,
depth = depth,
dim_head = dim_head,
heads = heads,
mlp_dim = mlp_dim,
dropout = dropout,
)
self.final_norm = nn.LayerNorm(dim)
self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
def forward(
self,
raw_audio_or_spec, # (b t) | (b f t)
return_hiddens = False
):
batch, device = raw_audio_or_spec.shape[0], raw_audio_or_spec.device
assert (self.accept_spec and raw_audio_or_spec.ndim == 3) or (not self.accept_spec and raw_audio_or_spec.ndim == 2)
if self.accept_spec:
spec = rearrange(raw_audio_or_spec, 'b t f -> b f t')
else:
spec = self.spec(raw_audio_or_spec)
# automatically crop if audio does not yield a 2d spectrogram that is divisible by patch sizes
height, width = spec.shape[-2:]
patch_height, patch_width = self.patch_size
rounded_height = height // patch_height * patch_height
rounded_width = width // patch_width * patch_width
spec = spec[..., :rounded_height, :rounded_width]
# to patches
tokens = self.to_patch_tokens(spec)
# get number of patches along height and width
_, num_patch_height, num_patch_width, _ = tokens.shape
# 2d sinusoidal positional embedding
tokens = tokens + posemb_sincos_2d(tokens)
tokens = rearrange(tokens, 'b ... c -> b (...) c')
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
tokens, packed_shape = pack((register_tokens, tokens), 'b * d')
# attention
attended, hiddens = self.transformer(tokens, return_hiddens = True)
# final global average and norm (most recent papers show this is superior to CLS token)
normed = self.final_norm(attended)
if return_hiddens:
return normed, stack(hiddens)
register_tokens, normed = unpack(normed, packed_shape, 'b * d')
pooled = reduce(normed, 'b n d -> b d', 'mean')
maybe_logits = self.mlp_head(pooled)
return maybe_logits
class ViT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
pool = 'cls',
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
num_register_tokens = 0
):
super().__init__()
self.dim = dim
self.depth = depth
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
self.cls_token = nn.Parameter(torch.randn(dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
def forward(self, img, return_hiddens = False):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
x += self.pos_embedding[:n]
cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = b)
x, packed_shape = pack((register_tokens, cls_tokens, x), 'b * d')
x = self.dropout(x)
x, hiddens = self.transformer(x, return_hiddens = True)
# return the representation trajectory
if return_hiddens:
return x, stack(hiddens)
register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
x = self.to_latent(x)
return self.mlp_head(x)
# proposed VAT
# https://openreview.net/forum?id=TalHOvvLZu
# simple way to get SOTA on Libero dataset (beating fine-tuned pi-zero)
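# action tokens cross attend, layer by layer, to the hidden state trajectories of a ViT (vision)
# and an AST (audio), with optional FiLM task conditioning, register tokens, an optional extra
# state token, and an L1 loss on the predicted action chunk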
class VAAT(Module):
def __init__(
self,
vit: ViT | dict,
ast: AST | dict,
*,
dim,
depth,
heads,
dim_head,
dim_action,
mlp_dim,
num_views = None,
num_tasks = None,
dim_extra_token = None,
num_register_tokens = 4,
action_chunk_len = 7,
time_seq_len = 1,
dropout = 0.,
add_self_attn = True, # in the paper, they did not have any way for the action tokens to exchange information with the extra token, so it is added here as an option
self_attn_heads = 4,
self_attn_dim_head = 32,
ast_layer_indices: tuple[int, ...] | None = None,
vit_layer_indices: tuple[int, ...] | None = None
):
super().__init__()
# vit
if isinstance(vit, dict):
vit = ViT(**vit)
self.vit = vit
vit_dim = vit.dim
assert vit.depth == depth or exists(vit_layer_indices), 'if the VAAT depth is not equal to the ViT depth, you must pass in the indices of the ViT layers to be fed to the VAAT, in order from bottom to top'
vit_layer_indices = default(vit_layer_indices, tuple(range(depth)))
assert len(vit_layer_indices) == depth, f'number of vit layer indices {len(vit_layer_indices)} does not match the VAAT depth {depth}'
self.register_buffer('vit_layer_indices', tensor(vit_layer_indices), persistent = False)
# ast
if isinstance(ast, dict):
ast = AST(**ast)
self.ast = ast
ast_dim = ast.dim
assert ast.depth == depth or exists(ast_layer_indices), 'if the VAAT depth is not equal to the AST depth, you must pass in the indices of the AST layers to be fed to the VAAT, in order from bottom to top'
ast_layer_indices = default(ast_layer_indices, tuple(range(depth)))
assert len(ast_layer_indices) == depth, f'number of ast layer indices {len(ast_layer_indices)} does not match the VAAT depth {depth}'
self.register_buffer('ast_layer_indices', tensor(ast_layer_indices), persistent = False)
# handle maybe multiple frames
is_video = time_seq_len > 1
self.is_video = is_video
self.time_seq_len = time_seq_len
self.time_pos_emb = nn.Parameter(torch.randn(time_seq_len, vit_dim) * 1e-2) if is_video else None
# maybe view embeddings
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
# handle maybe task conditioning
self.has_tasks = exists(num_tasks)
if self.has_tasks:
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
# register tokens from Darcet et al.
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
# to action tokens
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
self.layers = ModuleList([])
for _ in range(depth):
maybe_film = FiLM(dim = dim) if self.has_tasks else None
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
self.layers.append(ModuleList([
maybe_film,
maybe_self_attn,
Attention(dim = dim, dim_context = vit_dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
Attention(dim = dim, dim_context = ast_dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
]))
self.final_norm = nn.LayerNorm(dim)
self.to_pred_action = nn.Linear(dim, dim_action, bias = False)
# handle the extra token
self.accept_extra_token = exists(dim_extra_token)
if exists(dim_extra_token):
self.to_extra_token = nn.Linear(dim_extra_token, dim)
def forward(
self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
audio_or_spec, # (b t) | (b f t) - batch, audio len | batch, spec freq, time
*,
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False,
freeze_ast = False
):
batch = video_or_image.shape[0]
return_loss = exists(actions)
# handle some various input dimensions
if video_or_image.ndim == 4:
video_or_image = rearrange(video_or_image, 'b c h w -> b 1 c h w')
assert (
(video_or_image.ndim == 5 and not self.is_video) or
(video_or_image.ndim == 6 and self.is_video)
)
if video_or_image.ndim == 5:
video_or_image = rearrange(video_or_image, 'b v c h w -> b v c 1 h w')
assert video_or_image.shape[3] == self.time_seq_len
# to images
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
images, packed_shape = pack([images], '* c h w')
# get representation trajectory from vit
vit_forward_context = torch.no_grad if freeze_vit else nullcontext
with vit_forward_context():
embed, hiddens = self.vit(images, return_hiddens = True)
hiddens = cat((hiddens, embed[None, ...]))
# extract the hiddens needed for the action cross attention
hiddens = hiddens[self.vit_layer_indices]
# pack temporarily for embedding
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers
# maybe add time embeddings
if self.is_video:
time_pos_emb = rearrange(self.time_pos_emb, 't d -> t 1 d')
hiddens = hiddens + time_pos_emb
# maybe view embeddings
if exists(self.view_emb):
assert self.view_emb.shape[0] == hiddens.shape[2]
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
hiddens = hiddens + view_emb
# get representation trajectory from ast
ast_forward_context = torch.no_grad if freeze_ast else nullcontext
with ast_forward_context():
audio_embed, audio_hiddens = self.ast(audio_or_spec, return_hiddens = True)
audio_hiddens = cat((audio_hiddens, audio_embed[None, ...]))
# extract the hiddens needed for the action cross attention
audio_hiddens = audio_hiddens[self.ast_layer_indices]
# maybe tasks
if exists(tasks):
assert self.has_tasks, '`num_tasks` must be set on `VAAT` for task conditioning'
task_emb = self.task_emb[tasks]
# cross from actions to representation trajectory
image_context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
audio_context = audio_hiddens # eventually handle views (stereo and beyond)
# get main action tokens and maybe append extra
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
has_extra = exists(extra)
if has_extra:
assert self.accept_extra_token
extra_token = self.to_extra_token(extra)
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
# cross attention
hiddens = [action_tokens]
for (maybe_film, maybe_self_attn, image_cross_attn, audio_cross_attn, ff), image_layer_context, audio_layer_context in zip(self.layers, image_context, audio_context):
if exists(tasks):
action_tokens = maybe_film(action_tokens, task_emb)
action_tokens = image_cross_attn(action_tokens, image_layer_context) + action_tokens
action_tokens = audio_cross_attn(action_tokens, audio_layer_context) + action_tokens
if exists(maybe_self_attn):
action_tokens = maybe_self_attn(action_tokens) + action_tokens
action_tokens = ff(action_tokens) + action_tokens
hiddens.append(action_tokens)
# unpack registers
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
# maybe unpack extra
if has_extra:
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
# norm and prediction
action_tokens = self.final_norm(action_tokens)
pred_action = self.to_pred_action(action_tokens)
if not return_loss:
if not return_hiddens:
return pred_action
return pred_action, stack(hiddens)
assert pred_action.shape[1] == actions.shape[1]
# they found l1 loss suffices
return F.l1_loss(pred_action, actions)
# quick test
if __name__ == '__main__':
vit = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 384,
heads = 8,
depth = 4,
mlp_dim = 384 * 4
)
ast = AST(
dim = 384,
depth = 4,
heads = 8,
num_classes = 1000,
patch_size = 16,
mlp_dim = 384 * 4
)
vat = VAAT(
vit,
ast,
dim = 512,
depth = 9,
heads = 8,
dim_head = 64,
mlp_dim = 2048,
dim_action = 20,
action_chunk_len = 7,
time_seq_len = 4,
num_views = 2,
num_tasks = 4,
add_self_attn = True,
dim_extra_token = 33, # extra token with some variable dimension
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
0, 0, 1, 1, 2, 2, 3, 3, 4
),
ast_layer_indices = (
1, 1, 1, 2, 2, 2, 3, 3, 3
)
)
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
audio = torch.randn(2, 14_100 * 5)
tasks = torch.randint(0, 4, (2,))
extra = torch.randn(2, 33) # extra internal state
actions = torch.randn(2, 7, 20) # actions for learning
loss = vat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
# after much training
pred_actions, hiddens = vat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
assert pred_actions.shape == (2, 7, 20)

vit_pytorch/vat.py
View File

@@ -1,4 +1,5 @@
from __future__ import annotations
from contextlib import nullcontext
import torch
import torch.nn.functional as F
@@ -21,6 +22,27 @@ def pair(t):
# classes
class FiLM(Module):
def __init__(
self,
dim,
):
super().__init__()
proj = nn.Linear(dim, dim * 2)
self.to_gamma_beta = nn.Sequential(
proj,
Rearrange('b (two d) -> two b 1 d', two = 2)
)
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)
def forward(self, tokens, cond):
gamma, beta = self.to_gamma_beta(cond)
return tokens * (gamma + 1.) + beta # zero-initialized projection, so conditioning starts as the identity
class FeedForward(Module):
def __init__(
self,
@@ -45,12 +67,14 @@ class Attention(Module):
def __init__(
self,
dim,
dim_context = None,
heads = 8,
dim_head = 64,
dropout = 0.,
cross_attend = False
):
super().__init__()
dim_context = default(dim_context, dim)
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
@@ -60,13 +84,13 @@ class Attention(Module):
self.norm = nn.LayerNorm(dim)
self.cross_attend = cross_attend
self.context_norm = nn.LayerNorm(dim) if cross_attend else None
self.context_norm = nn.LayerNorm(dim_context) if cross_attend else None
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
@@ -157,7 +181,8 @@ class ViT(Module):
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.
emb_dropout = 0.,
num_register_tokens = 0
):
super().__init__()
self.dim = dim
@@ -179,8 +204,8 @@ class ViT(Module):
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
self.cls_token = nn.Parameter(torch.randn(dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
@@ -190,13 +215,19 @@ class ViT(Module):
self.mlp_head = nn.Linear(dim, num_classes)
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
def forward(self, img, return_hiddens = False):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x += self.pos_embedding[:n]
cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = b)
x, packed_shape = pack((register_tokens, cls_tokens, x), 'b * d')
x = self.dropout(x)
x, hiddens = self.transformer(x, return_hiddens = True)
@@ -206,7 +237,9 @@ class ViT(Module):
if return_hiddens:
return x, stack(hiddens)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
x = self.to_latent(x)
return self.mlp_head(x)
@@ -228,7 +261,9 @@ class VAT(Module):
dim_action,
mlp_dim,
num_views = None,
num_tasks = None,
dim_extra_token = None,
num_register_tokens = 4,
action_chunk_len = 7,
time_seq_len = 1,
dropout = 0.,
@@ -266,6 +301,17 @@ class VAT(Module):
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
# handle maybe task conditioning
self.has_tasks = exists(num_tasks)
if self.has_tasks:
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
# register tokens from Darcet et al.
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
# to action tokens
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
@@ -273,11 +319,13 @@ class VAT(Module):
self.layers = ModuleList([])
for _ in range(depth):
maybe_film = FiLM(dim = dim) if self.has_tasks else None
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
self.layers.append(ModuleList([
maybe_film,
maybe_self_attn,
Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
Attention(dim = dim, dim_context = vit_dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
]))
@@ -294,8 +342,12 @@ class VAT(Module):
def forward(
self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
*,
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False
):
batch = video_or_image.shape[0]
return_loss = exists(actions)
@@ -323,7 +375,10 @@ class VAT(Module):
# get representation trajectory from vit
embed, hiddens = self.vit(images, return_hiddens = True)
vit_forward_context = torch.no_grad if freeze_vit else nullcontext
with vit_forward_context():
embed, hiddens = self.vit(images, return_hiddens = True)
hiddens = cat((hiddens, embed[None, ...]))
@@ -349,6 +404,13 @@ class VAT(Module):
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
hiddens = hiddens + view_emb
# maybe tasks
if exists(tasks):
assert self.has_tasks, f'`num_tasks` must be set on `VAT` for task conditioning'
task_emb = self.task_emb[tasks]
# cross from actions to representation trajectory
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
@@ -366,9 +428,20 @@ class VAT(Module):
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
# cross attention
for (maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
hiddens = [action_tokens]
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
if exists(tasks):
action_tokens = maybe_film(action_tokens, task_emb)
action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
@@ -377,6 +450,12 @@ class VAT(Module):
action_tokens = ff(action_tokens) + action_tokens
hiddens.append(action_tokens)
# unpack registers
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
# maybe unpack extra
if has_extra:
@@ -389,7 +468,10 @@ class VAT(Module):
pred_action = self.to_pred_action(action_tokens)
if not return_loss:
return pred_action
if not return_hiddens:
return pred_action
return pred_action, stack(hiddens)
assert pred_action.shape[1] == actions.shape[1]
@@ -405,10 +487,10 @@ if __name__ == '__main__':
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 512,
dim = 256,
heads = 8,
depth = 4,
mlp_dim = 2048
mlp_dim = 1024
)
vat = VAT(
@@ -422,6 +504,7 @@ if __name__ == '__main__':
action_chunk_len = 7,
time_seq_len = 4,
num_views = 2,
num_tasks = 4,
add_self_attn = True,
dim_extra_token = 33, # extra token with some variable dimension
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
@@ -430,15 +513,16 @@ if __name__ == '__main__':
)
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
tasks = torch.randint(0, 4, (2,))
extra = torch.randn(2, 33) # extra internal state
actions = torch.randn(2, 7, 20) # actions for learning
loss = vat(images, actions = actions, extra = extra)
loss = vat(images, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
# after much training
pred_actions = vat(images)
pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True)
assert pred_actions.shape == (2, 7, 20)

vit_pytorch/vit_with_decorr.py
View File

@@ -0,0 +1,234 @@
# https://arxiv.org/abs/2510.14657
# but instead of their decorr module updated with SGD, remove all projections and just return a decorrelation auxiliary loss
import torch
from torch import nn, stack, tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# decorr loss
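# penalizes the squared off-diagonal entries of the Gram matrix of the normalized tokens
# (features x features, averaged over tokens and feature pairs), pushing features towards decorrelation
# optionally computed on a random fraction of the tokens per sequence to save compute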
class DecorrelationLoss(Module):
def __init__(
self,
sample_frac = 1.,
soft_validate_num_sampled = False
):
super().__init__()
assert 0. <= sample_frac <= 1.
self.need_sample = sample_frac < 1.
self.sample_frac = sample_frac
self.soft_validate_num_sampled = soft_validate_num_sampled
self.register_buffer('zero', tensor(0.), persistent = False)
def forward(
self,
tokens
):
batch, seq_len, dim, device = *tokens.shape[-3:], tokens.device
if self.need_sample:
num_sampled = int(seq_len * self.sample_frac)
assert self.soft_validate_num_sampled or num_sampled >= 2.
if num_sampled <= 1:
return self.zero
# flatten any leading dims (eg. layers) together with the batch, then sample tokens along the sequence dim
tokens, packed_shape = pack([tokens], '* n d')
indices = torch.randn(tokens.shape[:2], device = device).argsort(dim = -1)[..., :num_sampled]
batch_arange = torch.arange(tokens.shape[0], device = device)
batch_arange = rearrange(batch_arange, 'b -> b 1')
tokens = tokens[batch_arange, indices]
tokens, = unpack(tokens, packed_shape, '* n d')
dist = einsum(tokens, tokens, '... n d, ... n e -> ... d e') / tokens.shape[-2]
eye = torch.eye(dim, device = device)
loss = dist.pow(2) * (1. - eye) / ((dim - 1) * dim)
loss = reduce(loss, '... b d e -> b', 'sum')
return loss.mean()
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
normed = self.norm(x)
return self.net(normed), normed
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.norm = nn.LayerNorm(dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
normed = self.norm(x)
qkv = self.to_qkv(normed).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out), normed
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
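# collect the normalized inputs to every attention and feedforward block - these are the
# activations the decorrelation auxiliary loss is computed over in the ViT below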
normed_inputs = []
for attn, ff in self.layers:
attn_out, attn_normed_inp = attn(x)
x = attn_out + x
ff_out, ff_normed_inp = ff(x)
x = ff_out + x
normed_inputs.append(attn_normed_inp)
normed_inputs.append(ff_normed_inp)
return self.norm(x), stack(normed_inputs)
class ViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., decorr_sample_frac = 1.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
# decorrelation loss related
self.has_decorr_loss = decorr_sample_frac > 0.
if self.has_decorr_loss:
self.decorr_loss = DecorrelationLoss(decorr_sample_frac)
self.register_buffer('zero', torch.tensor(0.), persistent = False)
def forward(
self,
img,
return_decorr_aux_loss = None
):
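# by default the decorrelation auxiliary loss is only computed during training, and only if enabled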
return_decorr_aux_loss = default(return_decorr_aux_loss, self.training) and self.has_decorr_loss
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x, normed_layer_inputs = self.transformer(x)
# maybe return decor loss
decorr_aux_loss = self.zero
if return_decorr_aux_loss:
decorr_aux_loss = self.decorr_loss(normed_layer_inputs)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x), decorr_aux_loss
# quick test
if __name__ == '__main__':
decorr_loss = DecorrelationLoss(0.1)
hiddens = torch.randn(6, 2, 512, 256)
decorr_loss(hiddens)
decorr_loss(hiddens[0])
decorr_loss = DecorrelationLoss(0.0001, soft_validate_num_sampled = True)
out = decorr_loss(hiddens)
assert out.item() == 0
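# a minimal end to end sanity check that the ViT returns (logits, decorrelation aux loss)
# the hyperparameters below are illustrative only
v = ViT(
image_size = 32,
patch_size = 4,
num_classes = 100,
dim = 128,
depth = 2,
heads = 4,
mlp_dim = 256,
decorr_sample_frac = 0.5
)
images = torch.randn(2, 3, 32, 32)
logits, decorr_aux_loss = v(images, return_decorr_aux_loss = True)
assert logits.shape == (2, 100)
(F.cross_entropy(logits, torch.randint(0, 100, (2,))) + 1e-1 * decorr_aux_loss).backward()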