From c07a55cc8354a69593b7a7580c6d16de228e8e9c Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 26 Oct 2025 18:09:57 -0700
Subject: [PATCH] add a vit with decorrelation auxiliary losses for mha and
 feedforwards, right after prenorm - this is in line with a paper from the
 netherlands, but without extra parameters or their manual sgd update scheme

---
 README.md                      |  12 ++
 pyproject.toml                 |   2 +-
 train_vit_decorr.py            | 107 +++++++++++++++++
 vit_pytorch/vit_with_decorr.py | 212 +++++++++++++++++++++++++++++++++
 4 files changed, 332 insertions(+), 1 deletion(-)
 create mode 100644 train_vit_decorr.py
 create mode 100644 vit_pytorch/vit_with_decorr.py

diff --git a/README.md b/README.md
index 2116596..4f0bb71 100644
--- a/README.md
+++ b/README.md
@@ -2201,4 +2201,16 @@ Coming from computer vision and new to transformers? Here are some resources tha
 }
 ```
 
+```bibtex
+@misc{carrigg2025decorrelationspeedsvisiontransformers,
+    title         = {Decorrelation Speeds Up Vision Transformers},
+    author        = {Kieran Carrigg and Rob van Gastel and Melda Yeghaian and Sander Dalm and Faysal Boughorbel and Marcel van Gerven},
+    year          = {2025},
+    eprint        = {2510.14657},
+    archivePrefix = {arXiv},
+    primaryClass  = {cs.CV},
+    url           = {https://arxiv.org/abs/2510.14657},
+}
+```
+
 *I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
diff --git a/pyproject.toml b/pyproject.toml
index f0fe609..d725f44 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.14.5"
+version = "1.15.1"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
diff --git a/train_vit_decorr.py b/train_vit_decorr.py
new file mode 100644
index 0000000..ca6a5dd
--- /dev/null
+++ b/train_vit_decorr.py
@@ -0,0 +1,107 @@
+# /// script
+# dependencies = [
+#     "accelerate",
+#     "vit-pytorch",
+#     "wandb"
+# ]
+# ///
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+import torchvision.transforms as T
+from torchvision.datasets import CIFAR100
+
+# constants
+
+BATCH_SIZE = 32
+LEARNING_RATE = 3e-4
+EPOCHS = 10
+DECORR_LOSS_WEIGHT = 1e-1
+
+TRACK_EXPERIMENT_ONLINE = False
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+# data
+
+transform = T.Compose([
+    T.ToTensor(),
+    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+])
+
+dataset = CIFAR100(
+    root = 'data',
+    download = True,
+    train = True,
+    transform = transform
+)
+
+dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)
+
+# model
+
+from vit_pytorch.vit_with_decorr import ViT
+
+vit = ViT(
+    dim = 128,
+    num_classes = 100,
+    image_size = 32,
+    patch_size = 4,
+    depth = 6,
+    heads = 8,
+    dim_head = 64,
+    mlp_dim = 128 * 4,
+    decorr_sample_frac = 1.   # use all tokens
+)
+
+# optim
+
+from torch.optim import Adam
+
+optim = Adam(vit.parameters(), lr = LEARNING_RATE)
+
+# prepare
+
+from accelerate import Accelerator
+
+accelerator = Accelerator()
+
+vit, optim, dataloader = accelerator.prepare(vit, optim, dataloader)
+
+# experiment
+
+import wandb
+
+wandb.init(
+    project = 'vit-decorr',
+    mode = 'online' if TRACK_EXPERIMENT_ONLINE else 'disabled'
+)
+
+wandb.run.name = 'baseline'
+
+# loop
+
+for _ in range(EPOCHS):
+    for images, labels in dataloader:
+
+        logits, decorr_aux_loss = vit(images)
+        loss = F.cross_entropy(logits, labels)
+
+        total_loss = (
+            loss +
+            decorr_aux_loss * DECORR_LOSS_WEIGHT
+        )
+
+        wandb.log(dict(loss = loss, decorr_loss = decorr_aux_loss))
+
+        accelerator.print(f'loss: {loss.item():.3f} | decorr aux loss: {decorr_aux_loss.item():.3f}')
+
+        accelerator.backward(total_loss)
+        optim.step()
+        optim.zero_grad()
diff --git a/vit_pytorch/vit_with_decorr.py b/vit_pytorch/vit_with_decorr.py
new file mode 100644
index 0000000..d298f8e
--- /dev/null
+++ b/vit_pytorch/vit_with_decorr.py
@@ -0,0 +1,212 @@
+# https://arxiv.org/abs/2510.14657
+# but instead of their decorrelation modules updated with a manual SGD scheme, remove the extra projections and just return a decorrelation auxiliary loss
+
+import torch
+from torch import nn, stack
+import torch.nn.functional as F
+from torch.nn import Module, ModuleList
+
+from einops import rearrange, repeat, reduce, einsum, pack, unpack
+from einops.layers.torch import Rearrange
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def pair(t):
+    return t if isinstance(t, tuple) else (t, t)
+
+# decorr loss
+
+class DecorrelationLoss(Module):
+    def __init__(
+        self,
+        sample_frac = 1.
+    ):
+        super().__init__()
+        assert 0. <= sample_frac <= 1.
+        self.need_sample = sample_frac < 1.
+        self.sample_frac = sample_frac
+
+    def forward(
+        self,
+        tokens
+    ):
+        # tokens are the stacked post-prenorm layer inputs, shape (layers, batch, seq, dim)
+
+        batch, seq_len, dim, device = *tokens.shape[-3:], tokens.device
+
+        if self.need_sample:
+            # randomly subsample a fraction of the tokens per (layer, batch) before computing the covariance
+
+            num_sampled = int(seq_len * self.sample_frac)
+            assert num_sampled >= 2
+
+            tokens, packed_shape = pack([tokens], '* n d')
+
+            indices = torch.randn(tokens.shape[:2], device = device).argsort(dim = -1)[..., :num_sampled]
+
+            batch_arange = torch.arange(tokens.shape[0], device = device)
+            batch_arange = rearrange(batch_arange, 'b -> b 1')
+
+            tokens = tokens[batch_arange, indices]
+
+            tokens, = unpack(tokens, packed_shape, '* n d')
+
+        # penalize the squared off-diagonal entries of the (dim x dim) token covariance
+
+        dist = einsum(tokens, tokens, '... n d, ... n e -> ... d e') / tokens.shape[-2]
+        eye = torch.eye(dim, device = device)
+
+        loss = dist.pow(2) * (1. - eye) / ((dim - 1) * dim)
+
+        loss = reduce(loss, 'l b d e -> b', 'sum')
+        return loss.sum()
+
+# classes
+
+class FeedForward(Module):
+    def __init__(self, dim, hidden_dim, dropout = 0.):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        # also return the post-prenorm input, for the decorrelation auxiliary loss
+        normed = self.norm(x)
+        return self.net(normed), normed
+
+class Attention(Module):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        project_out = not (heads == 1 and dim_head == dim)
+
+        self.norm = nn.LayerNorm(dim)
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+
+        self.attend = nn.Softmax(dim = -1)
+        self.dropout = nn.Dropout(dropout)
+
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        ) if project_out else nn.Identity()
+
+    def forward(self, x):
+        normed = self.norm(x)
+
+        qkv = self.to_qkv(normed).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
+
+        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+        attn = self.attend(dots)
+        attn = self.dropout(attn)
+
+        out = torch.matmul(attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+
+        return self.to_out(out), normed
+
+class Transformer(Module):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.layers = ModuleList([])
+
+        for _ in range(depth):
+            self.layers.append(ModuleList([
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
+            ]))
+
+    def forward(self, x):
+
+        # collect the normed inputs right after each prenorm, for the decorrelation auxiliary loss
+
+        normed_inputs = []
+
+        for attn, ff in self.layers:
+            attn_out, attn_normed_inp = attn(x)
+            x = attn_out + x
+
+            ff_out, ff_normed_inp = ff(x)
+            x = ff_out + x
+
+            normed_inputs.append(attn_normed_inp)
+            normed_inputs.append(ff_normed_inp)
+
+        return self.norm(x), stack(normed_inputs)
+
+class ViT(Module):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., decorr_sample_frac = 1.):
+        super().__init__()
+        image_height, image_width = pair(image_size)
+        patch_height, patch_width = pair(patch_size)
+
+        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+
+        num_patches = (image_height // patch_height) * (image_width // patch_width)
+        patch_dim = channels * patch_height * patch_width
+        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
+
+        self.to_patch_embedding = nn.Sequential(
+            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
+            nn.LayerNorm(patch_dim),
+            nn.Linear(patch_dim, dim),
+            nn.LayerNorm(dim),
+        )
+
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.dropout = nn.Dropout(emb_dropout)
+
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+
+        self.pool = pool
+        self.to_latent = nn.Identity()
+
+        self.mlp_head = nn.Linear(dim, num_classes)
+
+        # decorrelation loss related
+
+        self.has_decorr_loss = decorr_sample_frac > 0.
+
+        if self.has_decorr_loss:
+            self.decorr_loss = DecorrelationLoss(decorr_sample_frac)
+
+        self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
+    def forward(
+        self,
+        img,
+        return_decorr_aux_loss = None
+    ):
+        return_decorr_aux_loss = default(return_decorr_aux_loss, self.training) and self.has_decorr_loss
+
+        x = self.to_patch_embedding(img)
+        b, n, _ = x.shape
+
+        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x += self.pos_embedding[:, :(n + 1)]
+        x = self.dropout(x)
+
+        x, normed_layer_inputs = self.transformer(x)
+
+        # maybe compute the decorrelation auxiliary loss
+
+        decorr_aux_loss = self.zero
+
+        if return_decorr_aux_loss:
+            decorr_aux_loss = self.decorr_loss(normed_layer_inputs)
+
+        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
+
+        x = self.to_latent(x)
+        return self.mlp_head(x), decorr_aux_loss
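A minimal usage sketch for the `vit_pytorch.vit_with_decorr.ViT` added by this patch. The hyperparameters, batch size, and the `0.1` auxiliary loss weight are illustrative (they mirror `train_vit_decorr.py`); `decorr_sample_frac = 0.5` exercises the token subsampling path, and the auxiliary loss is only computed in training mode unless `return_decorr_aux_loss = True` is passed.

```python
import torch
import torch.nn.functional as F

from vit_pytorch.vit_with_decorr import ViT

vit = ViT(
    image_size = 32,
    patch_size = 4,
    num_classes = 100,
    dim = 128,
    depth = 6,
    heads = 8,
    dim_head = 64,
    mlp_dim = 128 * 4,
    decorr_sample_frac = 0.5   # decorrelate a random half of the tokens per layer
)

images = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 100, (4,))

# forward pass returns the logits along with the decorrelation auxiliary loss

logits, decorr_aux_loss = vit(images)

# weigh the auxiliary loss into the total objective

total_loss = F.cross_entropy(logits, labels) + 0.1 * decorr_aux_loss
total_loss.backward()
```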