Compare commits

..

6 Commits

5 changed files with 448 additions and 13 deletions

View File

@@ -2201,4 +2201,16 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{carrigg2025decorrelationspeedsvisiontransformers,
title = {Decorrelation Speeds Up Vision Transformers},
author = {Kieran Carrigg and Rob van Gastel and Melda Yeghaian and Sander Dalm and Faysal Boughorbel and Marcel van Gerven},
year = {2025},
eprint = {2510.14657},
archivePrefix = {arXiv},
primaryClass = {cs.CV},
url = {https://arxiv.org/abs/2510.14657},
}
```
*I visualise a time when we will be to robots what dogs are to humans, and I'm rooting for the machines.* — Claude Shannon

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.14.1"
version = "1.15.4"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }

107
train_vit_decorr.py Normal file
View File

@@ -0,0 +1,107 @@
# /// script
# dependencies = [
# "accelerate",
# "vit-pytorch",
# "wandb"
# ]
# ///
# Trains a small ViT on CIFAR-100 with cross entropy plus a weighted
# decorrelation auxiliary loss returned by the model itself.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision.transforms as T
from torchvision.datasets import CIFAR100

# constants

BATCH_SIZE = 32
LEARNING_RATE = 3e-4
EPOCHS = 10
DECORR_LOSS_WEIGHT = 1e-1
TRACK_EXPERIMENT_ONLINE = False

# helpers

def exists(v):
    return v is not None

# data

transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = CIFAR100(root = 'data', download = True, train = True, transform = transform)

dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)

# model

from vit_pytorch.vit_with_decorr import ViT

vit = ViT(
    dim = 128,
    num_classes = 100,
    image_size = 32,
    patch_size = 4,
    depth = 6,
    heads = 8,
    dim_head = 64,
    mlp_dim = 128 * 4,
    decorr_sample_frac = 1.,  # use all tokens
)

# optim

from torch.optim import Adam

optim = Adam(vit.parameters(), lr = LEARNING_RATE)

# prepare

from accelerate import Accelerator

accelerator = Accelerator()

vit, optim, dataloader = accelerator.prepare(vit, optim, dataloader)

# experiment

import wandb

wandb_mode = 'online' if TRACK_EXPERIMENT_ONLINE else 'disabled'

wandb.init(project = 'vit-decorr', mode = wandb_mode)

wandb.run.name = 'baseline'

# loop

for _ in range(EPOCHS):
    for images, labels in dataloader:
        # the model returns the class logits together with the auxiliary loss
        logits, decorr_aux_loss = vit(images)

        loss = F.cross_entropy(logits, labels)

        # only the optimised objective carries the auxiliary weight;
        # the two terms are logged / printed unweighted
        total_loss = loss + decorr_aux_loss * DECORR_LOSS_WEIGHT

        wandb.log({'loss': loss, 'decorr_loss': decorr_aux_loss})

        accelerator.print(f'loss: {loss.item():.3f} | decorr aux loss: {decorr_aux_loss.item():.3f}')

        accelerator.backward(total_loss)

        optim.step()
        optim.zero_grad()

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
from contextlib import nullcontext
import torch
import torch.nn.functional as F
@@ -21,6 +22,27 @@ def pair(t):
# classes
class FiLM(Module):
    """Feature-wise linear modulation of a token sequence by a conditioning vector.

    A single linear layer maps the conditioning vector `cond` of shape
    `(b, dim)` to a per-feature scale and shift, each broadcast over the
    token axis as `(b, 1, dim)`.

    The projection is zero-initialised, so the predicted scale and shift are
    both zero at the start of training. The scale is therefore applied as
    `1 + gamma`, which makes the freshly constructed layer an identity.
    Multiplying by the raw `gamma` would zero out every token at
    initialisation.
    """

    def __init__(
        self,
        dim,
    ):
        super().__init__()

        proj = nn.Linear(dim, dim * 2)

        # the projection output is split in two along the feature axis:
        # index 0 is the scale (gamma), index 1 is the shift (beta)
        self.to_gamma_beta = nn.Sequential(
            proj,
            Rearrange('b (two d) -> two b 1 d', two = 2)
        )

        # zero init so the modulation starts out as a no-op
        nn.init.zeros_(proj.weight)
        nn.init.zeros_(proj.bias)

    def forward(self, tokens, cond):
        """Return `tokens` scaled by `1 + gamma(cond)` and shifted by `beta(cond)`."""
        gamma, beta = self.to_gamma_beta(cond)

        return tokens * (gamma + 1.) + beta
class FeedForward(Module):
def __init__(
self,
@@ -157,7 +179,8 @@ class ViT(Module):
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.
emb_dropout = 0.,
num_register_tokens = 0
):
super().__init__()
self.dim = dim
@@ -179,8 +202,8 @@ class ViT(Module):
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
self.cls_token = nn.Parameter(torch.randn(dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
@@ -190,13 +213,19 @@ class ViT(Module):
self.mlp_head = nn.Linear(dim, num_classes)
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
def forward(self, img, return_hiddens = False):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x += self.pos_embedding[:n]
cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = b)
x, packed_shape = pack((register_tokens, cls_tokens, x), 'b * d')
x = self.dropout(x)
x, hiddens = self.transformer(x, return_hiddens = True)
@@ -206,7 +235,9 @@ class ViT(Module):
if return_hiddens:
return x, stack(hiddens)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
cls_tokens, x, register_tokens = unpack(x, packed_shape, 'b * d')
x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
x = self.to_latent(x)
return self.mlp_head(x)
@@ -228,7 +259,9 @@ class VAT(Module):
dim_action,
mlp_dim,
num_views = None,
num_tasks = None,
dim_extra_token = None,
num_register_tokens = 4,
action_chunk_len = 7,
time_seq_len = 1,
dropout = 0.,
@@ -266,6 +299,17 @@ class VAT(Module):
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
# handle maybe task conditioning
self.has_tasks = exists(num_tasks)
if self.has_tasks:
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
# register tokens from Darcet et al.
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
# to action tokens
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
@@ -273,9 +317,11 @@ class VAT(Module):
self.layers = ModuleList([])
for _ in range(depth):
maybe_film = FiLM(dim = dim) if self.has_tasks else None
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
self.layers.append(ModuleList([
maybe_film,
maybe_self_attn,
Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
@@ -294,8 +340,12 @@ class VAT(Module):
def forward(
self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
*,
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False
):
batch = video_or_image.shape[0]
return_loss = exists(actions)
@@ -323,7 +373,10 @@ class VAT(Module):
# get representation trajectory from vit
embed, hiddens = self.vit(images, return_hiddens = True)
vit_forward_context = torch.no_grad if freeze_vit else nullcontext
with vit_forward_context():
embed, hiddens = self.vit(images, return_hiddens = True)
hiddens = cat((hiddens, embed[None, ...]))
@@ -349,6 +402,13 @@ class VAT(Module):
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
hiddens = hiddens + view_emb
# maybe tasks
if exists(tasks):
assert self.has_tasks, f'`num_tasks` must be set on `VAT` for task conditioning'
task_emb = self.task_emb[tasks]
# cross from actions to representation trajectory
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
@@ -366,9 +426,20 @@ class VAT(Module):
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
# cross attention
for (maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
hiddens = [action_tokens]
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
if exists(tasks):
action_tokens = maybe_film(action_tokens, task_emb)
action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
@@ -377,6 +448,12 @@ class VAT(Module):
action_tokens = ff(action_tokens) + action_tokens
hiddens.append(action_tokens)
# unpack registers
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
# maybe unpack extra
if has_extra:
@@ -389,7 +466,10 @@ class VAT(Module):
pred_action = self.to_pred_action(action_tokens)
if not return_loss:
return pred_action
if not return_hiddens:
return pred_action
return pred_action, stack(hiddens)
assert pred_action.shape[1] == actions.shape[1]
@@ -422,6 +502,7 @@ if __name__ == '__main__':
action_chunk_len = 7,
time_seq_len = 4,
num_views = 2,
num_tasks = 4,
add_self_attn = True,
dim_extra_token = 33, # extra token with some variable dimension
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
@@ -430,15 +511,16 @@ if __name__ == '__main__':
)
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
tasks = torch.randint(0, 4, (2,))
extra = torch.randn(2, 33) # extra internal state
actions = torch.randn(2, 7, 20) # actions for learning
loss = vat(images, actions = actions, extra = extra)
loss = vat(images, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
# after much training
pred_actions = vat(images)
pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True)
assert pred_actions.shape == (2, 7, 20)

View File

@@ -0,0 +1,234 @@
# https://arxiv.org/abs/2510.14657
# but instead of their decorr module updated with SGD, remove all projections and just return a decorrelation auxiliary loss
import torch
from torch import nn, stack, tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(v):
    """Return True for any value other than `None`."""
    return not (v is None)
def default(v, d):
    """Return `v`, falling back to `d` only when `v` is `None`."""
    if v is None:
        return d
    return v
def pair(t):
    """Return `t` unchanged if it is a tuple, otherwise duplicate it into `(t, t)`."""
    if isinstance(t, tuple):
        return t
    return (t, t)
# decorr loss
class DecorrelationLoss(Module):
    """Auxiliary loss penalising off-diagonal entries of the feature second-moment matrix.

    Args:
        sample_frac: fraction of sequence positions used to estimate the
            matrix, in `[0, 1]`. With `1.` every position is used and no
            sampling takes place.
        soft_validate_num_sampled: when fewer than two positions would be
            sampled, return a zero loss instead of failing the assertion.
    """

    def __init__(
        self,
        sample_frac = 1.,
        soft_validate_num_sampled = False
    ):
        super().__init__()
        assert 0. <= sample_frac <= 1.

        self.need_sample = sample_frac < 1.
        self.sample_frac = sample_frac
        self.soft_validate_num_sampled = soft_validate_num_sampled

        # non-persistent buffer so the early-exit value follows `.to(device)`
        # without ending up in the state dict
        self.register_buffer('zero', tensor(0.), persistent = False)

    def forward(
        self,
        tokens
    ):
        """Compute the scalar decorrelation loss.

        Args:
            tokens: tensor whose last three axes are `(batch, seq, dim)`;
                any number of leading axes (e.g. a layer axis) is allowed.

        Returns:
            A 0-dim tensor: the squared off-diagonal entries, normalised by
            `dim * (dim - 1)`, summed over everything except the batch axis
            and then averaged over the batch.

        Raises:
            AssertionError: if sampling would keep fewer than two positions
                and `soft_validate_num_sampled` is False.
        """
        batch, seq_len, dim = tokens.shape[-3:]
        device = tokens.device

        if self.need_sample:
            num_sampled = int(seq_len * self.sample_frac)

            assert self.soft_validate_num_sampled or num_sampled >= 2.

            if num_sampled <= 1:
                return self.zero

            # collapse every leading axis so each row is one sequence
            leading_shape = tokens.shape[:-2]
            flat_tokens = tokens.reshape(-1, seq_len, dim)

            # an independent random subset of positions for every sequence:
            # argsort of random scores is a random permutation, of which the
            # first `num_sampled` entries are kept
            scores = torch.randn(flat_tokens.shape[:2], device = device)
            indices = scores.argsort(dim = -1)[:, :num_sampled]

            row_arange = torch.arange(flat_tokens.shape[0], device = device).unsqueeze(-1)

            flat_tokens = flat_tokens[row_arange, indices]
            tokens = flat_tokens.reshape(*leading_shape, num_sampled, dim)

        # (..., dim, dim) second-moment matrix, averaged over the positions used
        dist = tokens.transpose(-1, -2) @ tokens / tokens.shape[-2]

        eye = torch.eye(dim, device = device)

        # keep only off-diagonal terms, normalised by how many there are
        loss = dist.pow(2) * (1. - eye) / ((dim - 1) * dim)

        # sum over everything but the batch axis, then average over the batch
        per_sample = loss.sum(dim = (-1, -2)).reshape(-1, batch).sum(dim = 0)

        return per_sample.mean()
# classes
class FeedForward(Module):
    """Pre-norm MLP block that also exposes its normalised input.

    The LayerNorm is kept outside the `nn.Sequential` so that the normalised
    activations can be returned alongside the block output.

    Args:
        dim: feature dimension of the input and output.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability applied after the activation and after
            the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Return `(mlp(norm(x)), norm(x))`."""
        normed = self.norm(x)

        # the MLP must consume the normalised activations, as the attention
        # block in this file does; feeding `x` would bypass the LayerNorm
        return self.net(normed), normed
class Attention(Module):
    """Pre-norm multi-head self attention that also exposes its normalised input.

    Args:
        dim: feature dimension of the input and output.
        heads: number of attention heads.
        dim_head: feature dimension of each head.
        dropout: dropout probability applied to the attention weights and to
            the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads

        # a single head that already has the model width needs no output projection
        project_out = heads != 1 or dim_head != dim

        self.norm = nn.LayerNorm(dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if project_out:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Return `(attention_output, norm(x))` for `x` of shape `(b, n, dim)`."""
        normed = self.norm(x)

        # split heads: (b, n, h * d) -> (b, h, n, d)
        q, k, v = (
            part.reshape(*part.shape[:-1], self.heads, -1).transpose(1, 2)
            for part in self.to_qkv(normed).chunk(3, dim = -1)
        )

        sim = (q @ k.transpose(-1, -2)) * self.scale

        weights = self.dropout(self.attend(sim))

        # merge heads: (b, h, n, d) -> (b, n, h * d)
        merged = (weights @ v).transpose(1, 2).flatten(2)

        return self.to_out(merged), normed
class Transformer(Module):
    """Stack of residual attention / feedforward layers with a final LayerNorm.

    Besides the output it returns the normalised input seen by every
    sub-block, stacked along a new leading axis in execution order
    (attention, feedforward, attention, ...).
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

        self.layers = ModuleList([
            ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        """Return `(norm(output), stacked_normed_inputs)`."""
        collected = []

        for attn, ff in self.layers:
            attn_out, attn_normed_inp = attn(x)
            x = x + attn_out

            ff_out, ff_normed_inp = ff(x)
            x = x + ff_out

            collected.extend((attn_normed_inp, ff_normed_inp))

        return self.norm(x), stack(collected)
class ViT(Module):
    """Vision transformer classifier that additionally returns a decorrelation auxiliary loss.

    The forward pass returns `(logits, decorr_aux_loss)`. The auxiliary loss
    is computed from the normalised inputs of every transformer sub-block and
    is a zero tensor whenever it is not requested or not enabled.
    """

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        pool = 'cls',
        channels = 3,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        decorr_sample_frac = 1.
    ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # image -> flattened patches -> linear embedding
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # one extra position for the cls token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

        # decorrelation loss related

        self.has_decorr_loss = decorr_sample_frac > 0.

        if self.has_decorr_loss:
            self.decorr_loss = DecorrelationLoss(decorr_sample_frac)

        # returned in place of the auxiliary loss when it is not computed
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

    def forward(
        self,
        img,
        return_decorr_aux_loss = None
    ):
        """Classify `img` and optionally compute the decorrelation loss.

        Args:
            img: batch of images accepted by the patch embedding.
            return_decorr_aux_loss: whether to compute the auxiliary loss;
                when `None` it follows `self.training`. It is never computed
                if the loss was disabled at construction.

        Returns:
            `(logits, decorr_aux_loss)`.
        """
        wants_aux_loss = default(return_decorr_aux_loss, self.training) and self.has_decorr_loss

        patches = self.to_patch_embedding(img)
        b, n, _ = patches.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)

        x = torch.cat((cls_tokens, patches), dim = 1)
        x = x + self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x, normed_layer_inputs = self.transformer(x)

        # maybe return decorr loss

        if wants_aux_loss:
            decorr_aux_loss = self.decorr_loss(normed_layer_inputs)
        else:
            decorr_aux_loss = self.zero

        if self.pool == 'mean':
            pooled = x.mean(dim = 1)
        else:
            pooled = x[:, 0]

        pooled = self.to_latent(pooled)

        return self.mlp_head(pooled), decorr_aux_loss
# quick test
if __name__ == '__main__':
    # smoke test: the loss accepts input with and without a leading layer axis
    sampling_loss = DecorrelationLoss(0.1)

    hiddens = torch.randn(6, 2, 512, 256)

    for sample in (hiddens, hiddens[0]):
        sampling_loss(sample)

    # a fraction too small to sample two tokens yields zero under soft validation
    tiny_frac_loss = DecorrelationLoss(0.0001, soft_validate_num_sampled = True)

    result = tiny_frac_loss(hiddens)

    assert result.item() == 0