Compare commits

...

26 Commits

Author SHA1 Message Date
lucidrains
6aa0374313 register tokens for the AST in VAAT 2025-11-22 08:12:01 -08:00
lucidrains
b35a97de05 improvise a variant of VAT with audio cortex before fully generalizing it 2025-11-22 07:51:19 -08:00
lucidrains
1374b93145 the paper claims finetuning everything was better, but just allow for freezing the visual cortex, what PI proposes 2025-11-09 10:59:55 -08:00
lucidrains
4386742cd1 an option to return zero for decorr aux loss if insufficient samples 2025-11-09 10:08:06 -08:00
lucidrains
5cf8384c56 add a vit with decorrelation auxiliary losses for mha and feedforwards, right after prenorm - this is in line with a paper from the netherlands, but without extra parameters or their manual sgd update scheme 2025-10-28 12:17:32 -07:00
lucidrains
f7d59cecb5 some register tokens cannot hurt for VAT 2025-10-24 14:00:38 -07:00
lucidrains
a583cb5988 last tweak to vat 2025-10-23 12:21:09 -07:00
lucidrains
25871013f5 forgot task conditioning for vat 2025-10-23 10:55:16 -07:00
lucidrains
e66862bcd5 add VAT from iclr 2026, which claims SOTA on libero using a relatively simple scheme (#350) 2025-10-23 10:23:53 -07:00
lucidrains
39fd9ac8be for n-dimensional vit, have a method for fetching muon friendly parameters 2025-10-13 12:07:48 -07:00
lucidrains
3becf087bb have a language model address https://github.com/lucidrains/vit-pytorch/issues/348 2025-09-25 06:21:13 -07:00
lucidrains
f6bc14c81d able to return embed from vit-nd-rotary 2025-09-23 07:21:34 -07:00
lucidrains
845c844b3b add a vit nd with rotary nd, from Jerry Xiong at UIUC 2025-09-21 10:45:42 -07:00
lucidrains
5f2bc0c796 with assistance from claude (yes it did the einops equation building here), generalize to n-dimensions 2025-09-21 06:22:43 -07:00
lucidrains
35bf273037 1.11.7 2025-08-17 18:07:42 -07:00
Baraa sameeh
1123063a5e Make all CCT regularization parameters user-configurable. (#346) 2025-08-17 18:07:25 -07:00
lucidrains
f8bec5ede2 able to project the image embedding before applying time positional embedding for accept video wrapper 2025-08-13 10:15:18 -07:00
lucidrains
297e7d00a2 handle channel first for accept video wrapper 2025-08-03 08:29:40 -07:00
lucidrains
29ac8e143c fix when video time seq len less than max time seq len for video acceptor 2025-07-27 09:00:56 -07:00
lucidrains
e05cd6d8b8 some models only return embeddings with some kwarg on forward 2025-07-27 08:46:43 -07:00
lucidrains
b46233c3d6 need to be able to invoke with eval no grad 2025-07-27 08:25:58 -07:00
lucidrains
68e13a3c7d bit more flexible 2025-07-27 08:14:48 -07:00
lucidrains
b22dc0ecd2 add a wrapper for accepting video and processing the images individually, optionally able to add time positional embeddings - for use in two robotics work 2025-07-27 08:05:48 -07:00
lucidrains
db05a141a6 add the proposed jumbo vit from Fuller et al. of Carleton University 2025-03-05 10:50:34 -08:00
lucidrains
9f49a31977 1.9.2 2025-01-19 05:53:11 -08:00
JacobLinCool
ab63fc9cc8 remove duplicated qkv computation in na_vit_nested_tensor_3d.py (#341) 2025-01-19 05:52:46 -08:00
17 changed files with 2612 additions and 65 deletions

View File

@@ -18,9 +18,9 @@ jobs:
    runs-on: ubuntu-latest
    steps:
-   - uses: actions/checkout@v2
+   - uses: actions/checkout@v4
    - name: Set up Python
-     uses: actions/setup-python@v2
+     uses: actions/setup-python@v5
      with:
        python-version: '3.x'
    - name: Install dependencies

View File

@@ -18,18 +18,17 @@ jobs:
    python-version: [3.8, 3.9]
    steps:
-   - uses: actions/checkout@v2
+   - uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
-     uses: actions/setup-python@v2
+     uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
-       python -m pip install pytest
        python -m pip install wheel
        python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
-       if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+       python -m pip install -e .
+       python -m pip install pytest
    - name: Test with pytest
      run: |
-       python setup.py test
+       pytest -q

View File

@@ -2172,4 +2172,45 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@inproceedings{Fuller2025SimplerFV,
title = {Simpler Fast Vision Transformers with a Jumbo CLS Token},
author = {Anthony Fuller and Yousef Yassin and Daniel G. Kyrollos and Evan Shelhamer and James R. Green},
year = {2025},
url = {https://api.semanticscholar.org/CorpusID:276557720}
}
```
```bibtex
@misc{xiong2025ndrope,
author = {Jerry Xiong},
title = {On n-dimensional rotary positional embeddings},
year = {2025},
url = {https://jerryxio.ng/posts/nd-rope/}
}
```
```bibtex
@inproceedings{anonymous2025vat,
title = {{VAT}: Vision Action Transformer by Unlocking Full Representation of ViT},
author = {Anonymous},
booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
year = {2025},
url = {https://openreview.net/forum?id=TalHOvvLZu},
note = {under review}
}
```
```bibtex
@misc{carrigg2025decorrelationspeedsvisiontransformers,
title = {Decorrelation Speeds Up Vision Transformers},
author = {Kieran Carrigg and Rob van Gastel and Melda Yeghaian and Sander Dalm and Faysal Boughorbel and Marcel van Gerven},
year = {2025},
eprint = {2510.14657},
archivePrefix = {arXiv},
primaryClass = {cs.CV},
url = {https://arxiv.org/abs/2510.14657},
}
```
*I visualise a time when we will be to robots what dogs are to humans, and I'm rooting for the machines.* — Claude Shannon

63
pyproject.toml Normal file
View File

@@ -0,0 +1,63 @@
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.15.6"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" },
]
requires-python = ">=3.8"
keywords = [
"artificial intelligence",
"attention mechanism",
"image recognition",
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
dependencies = [
"einops>=0.7.0",
"torch>=1.10",
"torchvision",
]
[project.optional-dependencies]
test = [
"pytest",
"torch==2.4.0",
"torchvision==0.19.0",
]
[project.urls]
Homepage = "https://github.com/lucidrains/vit-pytorch"
Repository = "https://github.com/lucidrains/vit-pytorch"
[tool.setuptools]
include-package-data = true
[tool.setuptools.packages.find]
include = ["vit_pytorch*"]
exclude = ["examples*", "tests*", "test*"]
[tool.pytest.ini_options]
testpaths = ["tests", "."]
python_files = ["test_*.py", "*_test.py"]
addopts = "-q"
filterwarnings = [
"ignore::FutureWarning",
]
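# usage note (not part of the file): with the optional `test` extra above, a development
# install would be `pip install -e '.[test]'`, after which a bare `pytest` picks up the
# `[tool.pytest.ini_options]` settings below (testpaths, quiet output, warning filters)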

View File

@@ -1,42 +0,0 @@
from setuptools import setup, find_packages
with open('README.md') as f:
long_description = f.read()
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.9.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description = long_description,
long_description_content_type = 'text/markdown',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
url = 'https://github.com/lucidrains/vit-pytorch',
keywords = [
'artificial intelligence',
'attention mechanism',
'image recognition'
],
install_requires=[
'einops>=0.7.0',
'torch>=1.10',
'torchvision'
],
setup_requires=[
'pytest-runner',
],
tests_require=[
'pytest',
'torch==2.4.0',
'torchvision==0.19.0'
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)

BIN
tests/.DS_Store vendored Normal file

Binary file not shown.

View File

@@ -1,7 +1,7 @@
import torch
from vit_pytorch import ViT
-def test():
+def test_vit():
    v = ViT(
        image_size = 256,
        patch_size = 32,

107
train_vit_decorr.py Normal file
View File

@@ -0,0 +1,107 @@
# /// script
# dependencies = [
# "accelerate",
# "vit-pytorch",
# "wandb"
# ]
# ///
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.datasets import CIFAR100
# constants
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
EPOCHS = 10
DECORR_LOSS_WEIGHT = 1e-1
TRACK_EXPERIMENT_ONLINE = False
# helpers
def exists(v):
return v is not None
# data
transform = T.Compose([
T.ToTensor(),
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = CIFAR100(
root = 'data',
download = True,
train = True,
transform = transform
)
dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)
# model
from vit_pytorch.vit_with_decorr import ViT
vit = ViT(
dim = 128,
num_classes = 100,
image_size = 32,
patch_size = 4,
depth = 6,
heads = 8,
dim_head = 64,
mlp_dim = 128 * 4,
decorr_sample_frac = 1. # use all tokens
)
# optim
from torch.optim import Adam
optim = Adam(vit.parameters(), lr = LEARNING_RATE)
# prepare
from accelerate import Accelerator
accelerator = Accelerator()
vit, optim, dataloader = accelerator.prepare(vit, optim, dataloader)
# experiment
import wandb
wandb.init(
project = 'vit-decorr',
mode = 'disabled' if not TRACK_EXPERIMENT_ONLINE else 'online'
)
wandb.run.name = 'baseline'
# loop
for _ in range(EPOCHS):
for images, labels in dataloader:
logits, decorr_aux_loss = vit(images)
loss = F.cross_entropy(logits, labels)
total_loss = (
loss +
decorr_aux_loss * DECORR_LOSS_WEIGHT
)
wandb.log(dict(loss = loss, decorr_loss = decorr_aux_loss))
accelerator.print(f'loss: {loss.item():.3f} | decorr aux loss: {decorr_aux_loss.item():.3f}')
accelerator.backward(total_loss)
optim.step()
optim.zero_grad()
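# a minimal evaluation sketch (illustrative only, reusing `vit` and the last `images` batch from the loop above) -
# the decorrelation ViT's forward returns (logits, decorr_aux_loss); the aux loss is simply discarded at eval time
vit.eval()
with torch.no_grad():
    eval_logits, _ = vit(images)
    preds = eval_logits.argmax(dim = -1)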

View File

@@ -0,0 +1,161 @@
from contextlib import nullcontext
import torch
from torch import is_tensor, randn
from torch.nn import Module, Linear, Parameter
from torch.utils._pytree import tree_flatten, tree_unflatten
from einops import rearrange, repeat
# helper functions
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
# classes
class AcceptVideoWrapper(Module):
def __init__(
self,
image_net: Module,
forward_function = 'forward',
add_time_pos_emb = False,
dim_emb = None,
time_seq_len = None,
embed_is_channel_first = False,
output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
proj_embed_to_dim = None
):
super().__init__()
self.image_net = image_net
self.forward_function = forward_function # for openclip, used in TRI-LBM
self.add_time_pos_emb = add_time_pos_emb
self.output_pos_add_pos_emb = output_pos_add_pos_emb
# maybe project the image embedding
self.embed_proj = None
if exists(proj_embed_to_dim):
assert exists(dim_emb), '`dim_emb` must be passed in'
self.embed_proj = Linear(dim_emb, proj_embed_to_dim)
# time positional embedding
if add_time_pos_emb:
assert exists(dim_emb) and exists(time_seq_len), '`dim_emb` and `time_seq_len` must be set if adding positional embeddings to the output'
self.time_seq_len = time_seq_len
dim_pos_emb = default(proj_embed_to_dim, dim_emb)
self.pos_emb = Parameter(randn(time_seq_len, dim_pos_emb) * 1e-2)
self.embed_is_channel_first = embed_is_channel_first
def forward(
self,
video, # (b c t h w)
eval_with_no_grad = False,
forward_kwargs = dict()
):
add_time_pos_emb = self.add_time_pos_emb
time = video.shape[2]
# maybe validate time positional embedding
if add_time_pos_emb:
assert time <= self.time_seq_len, f'received video with {time} frames but `time_seq_len` ({self.time_seq_len}) is too low'
video = rearrange(video, 'b c t h w -> b t c h w')
video = rearrange(video, 'b t ... -> (b t) ...')
# forward through image net for outputs
func = getattr(self.image_net, self.forward_function)
if eval_with_no_grad:
self.image_net.eval()
context = torch.no_grad if eval_with_no_grad else nullcontext
with context():
outputs = func(video, **forward_kwargs)
# handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
outputs, tree_spec = tree_flatten(outputs)
outputs = tuple(rearrange(t, '(b t) ... -> b t ...', t = time) if is_tensor(t) and t.numel() > 1 else t for t in outputs)
# maybe project embedding
if exists(self.embed_proj):
outputs = list(outputs)
embed = outputs[self.output_pos_add_pos_emb]
outputs[self.output_pos_add_pos_emb] = self.embed_proj(embed)
# maybe add time positional embedding
if add_time_pos_emb:
outputs = list(outputs)
embed = outputs[self.output_pos_add_pos_emb]
pos_emb = rearrange(self.pos_emb, 't d -> 1 t d')
# handle the network outputting embeddings with spatial dimensions intact - assume embedded dimension is last
dims_to_unsqueeze = embed.ndim - pos_emb.ndim
one_dims = ((1,) * dims_to_unsqueeze)
if self.embed_is_channel_first:
pos_emb = pos_emb.reshape(*pos_emb.shape, *one_dims)
else:
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *one_dims, pos_emb.shape[-1])
pos_emb = pos_emb[:, :embed.shape[1]]
embed = embed + pos_emb
outputs[self.output_pos_add_pos_emb] = embed
return tree_unflatten(outputs, tree_spec)
# main
if __name__ == '__main__':
from vit_pytorch import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
videos = torch.randn(1, 3, 7, 256, 256)
# step up the difficulty and return embeddings for robotics
from vit_pytorch.extractor import Extractor
v = Extractor(v)
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024, proj_embed_to_dim = 512)
logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
assert logits.shape == (1, 7, 1000)
assert embeddings.shape == (1, 7, 65, 512)
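# a simpler usage sketch (assumes the same extractor-wrapped `v` and `videos` from above, without projection or time embeddings)
plain_acceptor = AcceptVideoWrapper(v)
logits, embeddings = plain_acceptor(videos, eval_with_no_grad = True)
assert logits.shape == (1, 7, 1000)
assert embeddings.shape == (1, 7, 65, 1024) # embeddings keep the vit's own dimension when `proj_embed_to_dim` is not set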

View File

@@ -316,6 +316,9 @@ class CCT(nn.Module):
        pooling_kernel_size=3,
        pooling_stride=2,
        pooling_padding=1,
+       dropout_rate=0.,
+       attention_dropout=0.1,
+       stochastic_depth_rate=0.1,
        *args, **kwargs
    ):
        super().__init__()
@@ -340,9 +343,9 @@ class CCT(nn.Module):
                          width=img_width),
            embedding_dim=embedding_dim,
            seq_pool=True,
-           dropout_rate=0.,
-           attention_dropout=0.1,
-           stochastic_depth=0.1,
+           dropout_rate=dropout_rate,
+           attention_dropout=attention_dropout,
+           stochastic_depth_rate=stochastic_depth_rate,
            *args, **kwargs)
    def forward(self, x):

204
vit_pytorch/jumbo_vit.py Normal file
View File

@@ -0,0 +1,204 @@
# Simpler Fast Vision Transformers with a Jumbo CLS Token
# https://arxiv.org/abs/2502.15021
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def divisible_by(num, den):
return (num % den) == 0
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = temperature ** -omega
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pos_emb = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pos_emb.type(dtype)
# classes
def FeedForward(dim, hidden_dim):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class JumboViT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
num_jumbo_cls = 1, # differing from paper, allow for multiple jumbo cls, so one could break it up into 2 jumbo cls tokens with 3x the dim, as an example
jumbo_cls_k = 6, # they use a CLS token with this factor times the dimension - 6 was the value they settled on
jumbo_ff_mult = 2, # expansion factor of the jumbo cls token feedforward
channels = 3,
dim_head = 64
):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
jumbo_cls_dim = dim * jumbo_cls_k
self.jumbo_cls_token = nn.Parameter(torch.zeros(num_jumbo_cls, jumbo_cls_dim))
jumbo_cls_to_tokens = Rearrange('b n (k d) -> b (n k) d', k = jumbo_cls_k)
self.jumbo_cls_to_tokens = jumbo_cls_to_tokens
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
# attention and feedforwards
self.jumbo_ff = nn.Sequential(
Rearrange('b (n k) d -> b n (k d)', k = jumbo_cls_k),
FeedForward(jumbo_cls_dim, int(jumbo_cls_dim * jumbo_ff_mult)), # separate parameters from the token feedforwards, but weight tied across depth for parameter efficiency
jumbo_cls_to_tokens
)
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim),
]))
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
batch, device = img.shape[0], img.device
x = self.to_patch_embedding(img)
# pos embedding
pos_emb = self.pos_embedding.to(device, dtype = x.dtype)
x = x + pos_emb
# add cls tokens
cls_tokens = repeat(self.jumbo_cls_token, 'nj d -> b nj d', b = batch)
jumbo_tokens = self.jumbo_cls_to_tokens(cls_tokens)
x, cls_packed_shape = pack([jumbo_tokens, x], 'b * d')
# attention and feedforwards
for layer, (attn, ff) in enumerate(self.layers, start = 1):
is_last = layer == len(self.layers)
x = attn(x) + x
# jumbo feedforward
jumbo_cls_tokens, x = unpack(x, cls_packed_shape, 'b * d')
x = ff(x) + x
jumbo_cls_tokens = self.jumbo_ff(jumbo_cls_tokens) + jumbo_cls_tokens
if is_last:
continue
x, _ = pack([jumbo_cls_tokens, x], 'b * d')
pooled = reduce(jumbo_cls_tokens, 'b n d -> b d', 'mean')
# normalization and project to logits
embed = self.norm(pooled)
embed = self.to_latent(embed)
logits = self.linear_head(embed)
return logits
# copy pasteable file
if __name__ == '__main__':
v = JumboViT(
num_classes = 1000,
image_size = 64,
patch_size = 8,
dim = 16,
depth = 2,
heads = 2,
mlp_dim = 32,
jumbo_cls_k = 3,
jumbo_ff_mult = 2,
)
images = torch.randn(1, 3, 64, 64)
logits = v(images)
assert logits.shape == (1, 1000)
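# a sketch of the multiple jumbo cls option noted above (illustrative) - two jumbo tokens, mean pooled at the end
v2 = JumboViT(
    num_classes = 1000,
    image_size = 64,
    patch_size = 8,
    dim = 16,
    depth = 2,
    heads = 2,
    mlp_dim = 32,
    num_jumbo_cls = 2,
    jumbo_cls_k = 3
)
assert v2(images).shape == (1, 1000)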

View File

@@ -83,17 +83,6 @@ class Attention(Module):
        # split heads
        def split_heads(t):
            return t.unflatten(-1, (self.heads, self.dim_head)).transpose(1, 2).contiguous()
-       # queries, keys, values
-       query = self.to_queries(x)
-       key = self.to_keys(context)
-       value = self.to_values(context)
-       # split heads
-       def split_heads(t):
-           return t.unflatten(-1, (self.heads, self.dim_head))

744
vit_pytorch/vaat.py Normal file
View File

@@ -0,0 +1,744 @@
# vision-audio-action transformer - vaat
from __future__ import annotations
from contextlib import nullcontext
import torch
import torch.nn.functional as F
from torch import nn, cat, stack, arange, tensor
from torch.nn import Module, ModuleList
from torchaudio.transforms import Spectrogram
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# 2d sinusoidal positional embedding
# simple vit paper shows it is good enough compared to learned
def posemb_sincos_2d(
patches,
temperature = 10000,
dtype = torch.float32
):
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
y, x = torch.meshgrid(arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
omega = arange(dim // 4, device = device) / (dim // 4 - 1)
omega = temperature ** -omega
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
pe = pe.type(dtype)
return rearrange(pe, '(h w) d -> h w d', h = h, w = w)
# classes
class FiLM(Module):
def __init__(
self,
dim,
):
super().__init__()
proj = nn.Linear(dim, dim * 2)
self.to_gamma_beta = nn.Sequential(
proj,
Rearrange('b (two d) -> two b 1 d', two = 2)
)
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)
def forward(self, tokens, cond):
gamma, beta = self.to_gamma_beta(cond)
return tokens * gamma + beta
class FeedForward(Module):
def __init__(
self,
dim,
hidden_dim,
dropout = 0.
):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(
self,
dim,
heads = 8,
dim_head = 64,
dropout = 0.,
dim_context = None,
cross_attend = False
):
super().__init__()
dim_context = default(dim_context, dim)
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.cross_attend = cross_attend
self.context_norm = nn.LayerNorm(dim_context) if cross_attend else None
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, context = None):
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross attending, or vice versa'
x = self.norm(x)
# handle norming of context for cross attention
kv_input = x
if self.cross_attend:
context = self.context_norm(context)
kv_input = context
# project for queries, keys, values
qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(
self,
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout = 0.
):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(
self,
x,
return_hiddens = False
):
hiddens = []
for attn, ff in self.layers:
hiddens.append(x)
x = attn(x) + x
x = ff(x) + x
x = self.norm(x)
if not return_hiddens:
return x
return x, hiddens
class AST(Module):
# audio spectrogram transformer https://arxiv.org/abs/2104.01778
def __init__(
self,
dim,
depth,
mlp_dim,
num_classes = None,
patch_size = 16,
dim_head = 64,
heads = 8,
dropout = 0.,
accept_spec = False,
accept_spec_time_first = True,
spec_n_fft = 128,
spec_power = 2,
spec_win_length = 24,
spec_hop_length = None,
spec_pad = 0,
spec_center = True,
spec_pad_mode = 'reflect',
num_register_tokens = 4
):
super().__init__()
self.dim = dim
self.depth = depth
patch_height, patch_width = pair(patch_size)
patch_input_dim = patch_height * patch_width
self.patch_size = (patch_height, patch_width)
self.to_patch_tokens = nn.Sequential(
Rearrange('b (h p1) (w p2) -> b h w (p1 p2)', p1 = self.patch_size[0], p2 = self.patch_size[1]),
nn.LayerNorm(patch_input_dim),
nn.Linear(patch_input_dim, dim),
nn.LayerNorm(dim)
)
self.accept_spec = accept_spec
self.accept_spec_time_first = accept_spec_time_first
self.spec = Spectrogram(
n_fft = spec_n_fft,
power = spec_power,
win_length = spec_win_length,
hop_length = spec_hop_length,
pad = spec_pad,
center = spec_center,
pad_mode = spec_pad_mode
)
self.transformer = Transformer(
dim = dim,
depth = depth,
dim_head = dim_head,
heads = heads,
mlp_dim = mlp_dim,
dropout = dropout,
)
self.final_norm = nn.LayerNorm(dim)
self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
def forward(
self,
raw_audio_or_spec, # (b t) | (b f t)
return_hiddens = False
):
batch, device = raw_audio_or_spec.shape[0], raw_audio_or_spec.device
assert (self.accept_spec and raw_audio_or_spec.ndim == 3) or (not self.accept_spec and raw_audio_or_spec.ndim == 2)
if self.accept_spec and self.accept_spec_time_first:
spec = rearrange(raw_audio_or_spec, 'b t f -> b f t')
elif self.accept_spec:
spec = raw_audio_or_spec
else:
spec = self.spec(raw_audio_or_spec)
# automatically crop if audio does not yield a 2d spectrogram that is divisible by patch sizes
height, width = spec.shape[-2:]
patch_height, patch_width = self.patch_size
rounded_height = height // patch_height * patch_height
rounded_width = width // patch_width * patch_width
spec = spec[..., :rounded_height, :rounded_width]
# to patches
tokens = self.to_patch_tokens(spec)
# get number of patches along height and width
_, num_patch_height, num_patch_width, _ = tokens.shape
# 2d sinusoidal positional embedding
tokens = tokens + posemb_sincos_2d(tokens)
tokens = rearrange(tokens, 'b ... c -> b (...) c')
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
tokens, packed_shape = pack((register_tokens, tokens), 'b * d')
# attention
attended, hiddens = self.transformer(tokens, return_hiddens = True)
# final global average and norm (most recent papers show this is superior to CLS token)
normed = self.final_norm(attended)
if return_hiddens:
return normed, stack(hiddens)
register_tokens, normed = unpack(normed, packed_shape, 'b * d')
pooled = reduce(normed, 'b n d -> b d', 'mean')
maybe_logits = self.mlp_head(pooled)
return maybe_logits
class ViT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
pool = 'cls',
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
num_register_tokens = 0
):
super().__init__()
self.dim = dim
self.depth = depth
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
self.cls_token = nn.Parameter(torch.randn(dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
def forward(self, img, return_hiddens = False):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
x += self.pos_embedding[:n]
cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = b)
x, packed_shape = pack((register_tokens, cls_tokens, x), 'b * d')
x = self.dropout(x)
x, hiddens = self.transformer(x, return_hiddens = True)
# return the representation trajectory
if return_hiddens:
return x, stack(hiddens)
register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
x = self.to_latent(x)
return self.mlp_head(x)
# proposed VAT
# https://openreview.net/forum?id=TalHOvvLZu
# simple way to get SOTA on Libero dataset (beating fine-tuned pi-zero)
class VAAT(Module):
def __init__(
self,
vit: ViT | dict,
ast: AST | dict,
*,
dim,
depth,
heads,
dim_head,
dim_action,
mlp_dim,
num_views = None,
num_tasks = None,
dim_extra_token = None,
num_register_tokens = 4,
action_chunk_len = 7,
time_seq_len = 1,
dropout = 0.,
add_self_attn = True, # in the paper, they didn't have any way for the action tokens to exchange information with the extra token, so we'll just add it as an option
self_attn_heads = 4,
self_attn_dim_head = 32,
ast_layer_indices: tuple[int, ...] | None = None,
vit_layer_indices: tuple[int, ...] | None = None
):
super().__init__()
# vit
if isinstance(vit, dict):
vit = ViT(**vit)
self.vit = vit
vit_dim = vit.dim
assert vit.depth == depth or exists(vit_layer_indices), f'if the VAAT depth is not equal to the ViT depth, you must pass in the indices from the ViT to be layered to the VAAT in order from bottom to top'
vit_layer_indices = default(vit_layer_indices, tuple(range(depth)))
assert len(vit_layer_indices) == depth, f'number of vit layer indices {len(vit_layer_indices)} does not match the VAAT depth {depth}'
self.register_buffer('vit_layer_indices', tensor(vit_layer_indices), persistent = False)
# ast
if isinstance(ast, dict):
ast = AST(**ast)
self.ast = ast
ast_dim = ast.dim
assert ast.depth == depth or exists(ast_layer_indices), f'if the VAAT depth is not equal to the AST depth, you must pass in the indices from the AST to be layered to the VAAT in order from bottom to top'
ast_layer_indices = default(ast_layer_indices, tuple(range(depth)))
assert len(ast_layer_indices) == depth, f'number of ast layer indices {len(ast_layer_indices)} does not match the VAAT depth {depth}'
self.register_buffer('ast_layer_indices', tensor(ast_layer_indices), persistent = False)
# handle maybe multiple frames
is_video = time_seq_len > 1
self.is_video = is_video
self.time_seq_len = time_seq_len
self.time_pos_emb = nn.Parameter(torch.randn(time_seq_len, vit_dim) * 1e-2) if is_video else None
# maybe view embeddings
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
# handle maybe task conditioning
self.has_tasks = exists(num_tasks)
if self.has_tasks:
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
# register tokens from Darcet et al.
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
# to action tokens
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
self.layers = ModuleList([])
for _ in range(depth):
maybe_film = FiLM(dim = dim) if self.has_tasks else None
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
self.layers.append(ModuleList([
maybe_film,
maybe_self_attn,
Attention(dim = dim, dim_context = vit_dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
Attention(dim = dim, dim_context = ast_dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
]))
self.final_norm = nn.LayerNorm(dim)
self.to_pred_action = nn.Linear(dim, dim_action, bias = False)
# handle the extra token
self.accept_extra_token = exists(dim_extra_token)
if exists(dim_extra_token):
self.to_extra_token = nn.Linear(dim_extra_token, dim)
def forward(
self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
audio_or_spec, # (b t) | (b f t) - batch, audio len | batch, spec freq, time
*,
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False,
freeze_ast = False
):
batch = video_or_image.shape[0]
return_loss = exists(actions)
# handle some various input dimensions
if video_or_image.ndim == 4:
video_or_image = rearrange(video_or_image, 'b c h w -> b 1 c h w')
assert (
(video_or_image.ndim == 5 and not self.is_video) or
(video_or_image.ndim == 6 and self.is_video)
)
if video_or_image.ndim == 5:
video_or_image = rearrange(video_or_image, 'b v c h w -> b v c 1 h w')
assert video_or_image.shape[3] == self.time_seq_len
# to images
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
images, packed_shape = pack([images], '* c h w')
# get representation trajectory from vit
vit_forward_context = torch.no_grad if freeze_vit else nullcontext
with vit_forward_context():
embed, hiddens = self.vit(images, return_hiddens = True)
hiddens = cat((hiddens, embed[None, ...]))
# extract the hiddens needed for the action cross attention
hiddens = hiddens[self.vit_layer_indices]
# pack temporarily for embedding
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers
# maybe add time embeddings
if self.is_video:
time_pos_emb = rearrange(self.time_pos_emb, 't d -> t 1 d')
hiddens = hiddens + time_pos_emb
# maybe view embeddings
if exists(self.view_emb):
assert self.view_emb.shape[0] == hiddens.shape[2]
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
hiddens = hiddens + view_emb
# get representation trajectory from ast
ast_forward_context = torch.no_grad if freeze_ast else nullcontext
with ast_forward_context():
audio_embed, audio_hiddens = self.ast(audio_or_spec, return_hiddens = True)
audio_hiddens = cat((audio_hiddens, audio_embed[None, ...]))
# extract the hiddens needed for the action cross attention
audio_hiddens = audio_hiddens[self.ast_layer_indices]
# maybe tasks
if exists(tasks):
assert self.has_tasks, '`num_tasks` must be set on `VAAT` for task conditioning'
task_emb = self.task_emb[tasks]
# cross from actions to representation trajectory
image_context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
audio_context = audio_hiddens # eventually handle views (stereo and beyond)
# get main action tokens and maybe append extra
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
has_extra = exists(extra)
if has_extra:
assert self.accept_extra_token
extra_token = self.to_extra_token(extra)
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
# cross attention
hiddens = [action_tokens]
for (maybe_film, maybe_self_attn, image_cross_attn, audio_cross_attn, ff), image_layer_context, audio_layer_context in zip(self.layers, image_context, audio_context):
if exists(tasks):
action_tokens = maybe_film(action_tokens, task_emb)
action_tokens = image_cross_attn(action_tokens, image_layer_context) + action_tokens
action_tokens = audio_cross_attn(action_tokens, audio_layer_context) + action_tokens
if exists(maybe_self_attn):
action_tokens = maybe_self_attn(action_tokens) + action_tokens
action_tokens = ff(action_tokens) + action_tokens
hiddens.append(action_tokens)
# unpack registers
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
# maybe unpack extra
if has_extra:
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
# norm and prediction
action_tokens = self.final_norm(action_tokens)
pred_action = self.to_pred_action(action_tokens)
if not return_loss:
if not return_hiddens:
return pred_action
return pred_action, stack(hiddens)
assert pred_action.shape[1] == actions.shape[1]
# they found l1 loss suffices
return F.l1_loss(pred_action, actions)
# quick test
if __name__ == '__main__':
vit = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 384,
heads = 8,
depth = 4,
mlp_dim = 384 * 4
)
ast = AST(
dim = 384,
depth = 4,
heads = 8,
num_classes = 1000,
patch_size = 16,
mlp_dim = 384 * 4
)
vat = VAAT(
vit,
ast,
dim = 512,
depth = 9,
heads = 8,
dim_head = 64,
mlp_dim = 2048,
dim_action = 20,
action_chunk_len = 7,
time_seq_len = 4,
num_views = 2,
num_tasks = 4,
add_self_attn = True,
dim_extra_token = 33, # extra token with some variable dimension
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
0, 0, 1, 1, 2, 2, 3, 3, 4
),
ast_layer_indices = (
1, 1, 1, 2, 2, 2, 3, 3, 3
)
)
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
audio = torch.randn(2, 14_100 * 5)
tasks = torch.randint(0, 4, (2,))
extra = torch.randn(2, 33) # extra internal state
actions = torch.randn(2, 7, 20) # actions for learning
loss = vat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
# after much training
pred_actions, hiddens = vat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
assert pred_actions.shape == (2, 7, 20)
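# the audio cortex can be frozen just like the visual one, a sketch using the forward flags above
frozen_loss = vat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True, freeze_ast = True)
frozen_loss.backward()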

528
vit_pytorch/vat.py Normal file
View File

@@ -0,0 +1,528 @@
from __future__ import annotations
from contextlib import nullcontext
import torch
import torch.nn.functional as F
from torch import nn, cat, stack, tensor
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class FiLM(Module):
def __init__(
self,
dim,
):
super().__init__()
proj = nn.Linear(dim, dim * 2)
self.to_gamma_beta = nn.Sequential(
proj,
Rearrange('b (two d) -> two b 1 d', two = 2)
)
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)
def forward(self, tokens, cond):
gamma, beta = self.to_gamma_beta(cond)
return tokens * gamma + beta
class FeedForward(Module):
def __init__(
self,
dim,
hidden_dim,
dropout = 0.
):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(
self,
dim,
dim_context = None,
heads = 8,
dim_head = 64,
dropout = 0.,
cross_attend = False
):
super().__init__()
dim_context = default(dim_context, dim)
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.cross_attend = cross_attend
self.context_norm = nn.LayerNorm(dim_context) if cross_attend else None
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, context = None):
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross attending, or vice versa'
x = self.norm(x)
# handle norming of context for cross attention
kv_input = x
if self.cross_attend:
context = self.context_norm(context)
kv_input = context
# project for queries, keys, values
qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(
self,
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout = 0.
):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(
self,
x,
return_hiddens = False
):
hiddens = []
for attn, ff in self.layers:
hiddens.append(x)
x = attn(x) + x
x = ff(x) + x
x = self.norm(x)
if not return_hiddens:
return x
return x, hiddens
class ViT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
pool = 'cls',
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
num_register_tokens = 0
):
super().__init__()
self.dim = dim
self.depth = depth
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
self.cls_token = nn.Parameter(torch.randn(dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
def forward(self, img, return_hiddens = False):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
x += self.pos_embedding[:n]
cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = b)
x, packed_shape = pack((register_tokens, cls_tokens, x), 'b * d')
x = self.dropout(x)
x, hiddens = self.transformer(x, return_hiddens = True)
# return the representation trajectory
if return_hiddens:
return x, stack(hiddens)
register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
x = self.to_latent(x)
return self.mlp_head(x)
# proposed VAT
# https://openreview.net/forum?id=TalHOvvLZu
# simple way to get SOTA on Libero dataset (beating fine-tuned pi-zero)
class VAT(Module):
def __init__(
self,
vit: ViT | dict,
*,
dim,
depth,
heads,
dim_head,
dim_action,
mlp_dim,
num_views = None,
num_tasks = None,
dim_extra_token = None,
num_register_tokens = 4,
action_chunk_len = 7,
time_seq_len = 1,
dropout = 0.,
add_self_attn = True, # in the paper, they didn't have any way for the action tokens to exchange information with the extra token, so we'll just add it as an option
self_attn_heads = 4,
self_attn_dim_head = 32,
vit_layer_indices: tuple[int, ...] | None = None
):
super().__init__()
if isinstance(vit, dict):
vit = ViT(**vit)
self.vit = vit
vit_dim = vit.dim
assert vit.depth == depth or exists(vit_layer_indices), f'if the VAT depth is not equal to the ViT depth, you must pass in the indices from the ViT to be layered to the VAT in order from bottom to top'
vit_layer_indices = default(vit_layer_indices, tuple(range(depth)))
assert len(vit_layer_indices) == depth, f'number of vit layer indices {len(vit_layer_indices)} does not match the VAT depth {depth}'
self.register_buffer('layer_indices', tensor(vit_layer_indices), persistent = False)
# handle maybe multiple frames
is_video = time_seq_len > 1
self.is_video = is_video
self.time_seq_len = time_seq_len
self.time_pos_emb = nn.Parameter(torch.randn(time_seq_len, vit_dim) * 1e-2) if is_video else None
# maybe view embeddings
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
# handle maybe task conditioning
self.has_tasks = exists(num_tasks)
if self.has_tasks:
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
# register tokens from Darcet et al.
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
# to action tokens
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
self.layers = ModuleList([])
for _ in range(depth):
maybe_film = FiLM(dim = dim) if self.has_tasks else None
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
self.layers.append(ModuleList([
maybe_film,
maybe_self_attn,
Attention(dim = dim, dim_context = vit_dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
]))
self.final_norm = nn.LayerNorm(dim)
self.to_pred_action = nn.Linear(dim, dim_action, bias = False)
# handle the extra token
self.accept_extra_token = exists(dim_extra_token)
if exists(dim_extra_token):
self.to_extra_token = nn.Linear(dim_extra_token, dim)
def forward(
self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
*,
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False
):
batch = video_or_image.shape[0]
return_loss = exists(actions)
# handle some various input dimensions
if video_or_image.ndim == 4:
video_or_image = rearrange(video_or_image, 'b c h w -> b 1 c h w')
assert (
(video_or_image.ndim == 5 and not self.is_video) or
(video_or_image.ndim == 6 and self.is_video)
)
if video_or_image.ndim == 5:
video_or_image = rearrange(video_or_image, 'b v c h w -> b v c 1 h w')
assert video_or_image.shape[3] == self.time_seq_len
# to images
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
images, packed_shape = pack([images], '* c h w')
# get representation trajectory from vit
vit_forward_context = torch.no_grad if freeze_vit else nullcontext
with vit_forward_context():
embed, hiddens = self.vit(images, return_hiddens = True)
hiddens = cat((hiddens, embed[None, ...]))
# extract the hiddens needed for the action cross attention
hiddens = hiddens[self.layer_indices]
# pack temporarily for embedding
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers
# maybe add time embeddings
if self.is_video:
time_pos_emb = rearrange(self.time_pos_emb, 't d -> t 1 d')
hiddens = hiddens + time_pos_emb
# maybe view embeddings
if exists(self.view_emb):
assert self.view_emb.shape[0] == hiddens.shape[2]
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
hiddens = hiddens + view_emb
# maybe tasks
if exists(tasks):
assert self.has_tasks, f'`num_tasks` must be set on `VAT` for task conditioning'
task_emb = self.task_emb[tasks]
# cross from actions to representation trajectory
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
# get main action tokens and maybe append extra
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
has_extra = exists(extra)
if has_extra:
assert self.accept_extra_token
extra_token = self.to_extra_token(extra)
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
# cross attention
hiddens = [action_tokens]
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
if exists(tasks):
action_tokens = maybe_film(action_tokens, task_emb)
action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
if exists(maybe_self_attn):
action_tokens = maybe_self_attn(action_tokens) + action_tokens
action_tokens = ff(action_tokens) + action_tokens
hiddens.append(action_tokens)
# unpack registers
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
# maybe unpack extra
if has_extra:
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
# norm and prediction
action_tokens = self.final_norm(action_tokens)
pred_action = self.to_pred_action(action_tokens)
if not return_loss:
if not return_hiddens:
return pred_action
return pred_action, stack(hiddens)
assert pred_action.shape[1] == actions.shape[1]
# they found l1 loss suffices
return F.l1_loss(pred_action, actions)
# quick test
if __name__ == '__main__':
vit = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 256,
heads = 8,
depth = 4,
mlp_dim = 1024
)
vat = VAT(
vit,
dim = 512,
depth = 9,
heads = 8,
dim_head = 64,
mlp_dim = 2048,
dim_action = 20,
action_chunk_len = 7,
time_seq_len = 4,
num_views = 2,
num_tasks = 4,
add_self_attn = True,
dim_extra_token = 33, # extra token with some variable dimension
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
0, 0, 1, 1, 2, 2, 3, 3, 4
)
)
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
tasks = torch.randint(0, 4, (2,))
extra = torch.randn(2, 33) # extra internal state
actions = torch.randn(2, 7, 20) # actions for learning
loss = vat(images, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
# after much training
pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True)
assert pred_actions.shape == (2, 7, 20)

191
vit_pytorch/vit_nd.py Normal file
View File

@@ -0,0 +1,191 @@
from __future__ import annotations
import torch
from torch import nn
from torch.nn import Module
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def join(arr, delimiter = ' '):
return delimiter.join(arr)
def ensure_tuple(t, length):
if isinstance(t, (tuple, list)):
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
return tuple(t)
return (t,) * length
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class ViTND(Module):
def __init__(
self,
*,
ndim: int,
input_shape: int | tuple[int, ...],
patch_size: int | tuple[int, ...],
num_classes: int,
dim: int,
depth: int,
heads: int,
mlp_dim: int,
pool: str = 'cls',
channels: int = 3,
dim_head: int = 64,
dropout: float = 0.,
emb_dropout: float = 0.
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.ndim = ndim
self.pool = pool
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b ({join(dim_names)}) ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, x):
x = self.to_patch_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x[:, 1:].mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
if __name__ == '__main__':
model = ViTND(
ndim = 4,
input_shape = (8, 16, 32, 64),
patch_size = (2, 4, 4, 8),
num_classes = 1000,
dim = 512,
depth = 6,
heads = 8,
mlp_dim = 2048,
channels = 3,
dropout = 0.1,
emb_dropout = 0.1
)
occupancy_time = torch.randn(2, 3, 8, 16, 32, 64)
logits = model(occupancy_time)
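# for the configuration above, the logits should come out as (batch, num_classes)
assert logits.shape == (2, 1000)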

View File

@@ -0,0 +1,325 @@
from __future__ import annotations
import torch
from torch import nn, arange, cat, stack, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
def join(arr, delimiter = ' '):
return delimiter.join(arr)
def ensure_tuple(t, length):
if isinstance(t, (tuple, list)):
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
return tuple(t)
return (t,) * length
# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/
def _phi(m: int) -> float:
x = 2.0
for _ in range(10):
x = (1 + x) ** (1.0 / (m + 1.0))
return x
def make_directions(n: int, d: int) -> Tensor:
g = _phi(d)
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
z = torch.fmod(i * alpha, 1.0)
directions = torch.erfinv(2.0 * z - 1.0)
directions = l2norm(directions)
return directions.float()
class GoldenGateRoPENd(Module):
def __init__(
self,
dim_pos: int,
heads: int,
dim_head: int,
rope_min_freq: float = 1.0,
rope_max_freq: float = 10000.0,
rope_p_zero_freqs: float = 0.0, # proportion of frequencies set to 0
):
super().__init__()
n_freqs = dim_head // 2
n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
omega = cat((
torch.zeros(n_zero_freqs),
rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
))
directions = rearrange(
make_directions(heads * n_freqs, dim_pos),
'(h f) p -> h f p',
h = heads
)
omega_expanded = rearrange(omega, 'f -> f 1')
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
def forward(self, input: Tensor, pos: Tensor) -> Tensor:
# input shape: (b, h, n, d) where d = head_dim
# pos shape: (b, n, p) where p = pos_dim
# self.freqs shape: (h, f, p) where f = d // 2
x, y = input.float().chunk(2, dim = -1) # both (b, h, n, f)
# Expand dimensions for broadcasting
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
# Compute theta for each (batch, head, seq, freq)
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)
# Apply rotation
x_out = x * cos_theta - y * sin_theta
y_out = x * sin_theta + y * cos_theta
output = cat((x_out, y_out), dim=-1)
return output.type_as(input)
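# usage sketch with assumed shapes - queries and keys are rotated with the same positions,
# so their dot products depend only on relative positions, as with 1d rotary
# rope = GoldenGateRoPENd(dim_pos = 3, heads = 8, dim_head = 64)
# q = rope(torch.randn(2, 8, 128, 64), torch.randn(2, 128, 3)) # (b, h, n, d), (b, n, p) -> (b, h, n, d)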
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., rotary_emb = None):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.rotary_emb = rotary_emb
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, pos = None):
x = self.norm(x)
# queries and keys come from one shared projection, values from a separate one
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
# Apply rotary embeddings if available
if exists(self.rotary_emb):
assert exists(pos)
q = self.rotary_emb(q, pos)
k = self.rotary_emb(k, pos)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., rotary_emb = None):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, rotary_emb = rotary_emb),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, pos = None):
for attn, ff in self.layers:
x = attn(x, pos) + x
x = ff(x) + x
return self.norm(x)
class ViTND(Module):
def __init__(
self,
*,
ndim: int,
input_shape: int | tuple[int, ...],
patch_size: int | tuple[int, ...],
num_classes: int,
dim: int,
depth: int,
heads: int,
mlp_dim: int,
channels: int = 3,
dim_head: int = 64,
dropout: float = 0.,
emb_dropout: float = 0.,
rope_min_freq: float = 1.0,
rope_max_freq: float = 10000.0,
rope_p_zero_freqs: float = 0.0
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
self.ndim = ndim
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.dropout = nn.Dropout(emb_dropout)
# Create rotary embeddings
self.rotary_emb = GoldenGateRoPENd(
dim_pos = ndim,
heads = heads,
dim_head = dim_head,
rope_min_freq = rope_min_freq,
rope_max_freq = rope_max_freq,
rope_p_zero_freqs = rope_p_zero_freqs
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, rotary_emb = self.rotary_emb)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def muon_parameters(self):
params = []
for m in self.modules():
if isinstance(m, Attention):
params.extend([
m.to_v.weight,
m.to_out[0].weight
])
elif isinstance(m, FeedForward):
params.extend([
m.net[1].weight,
m.net[-2].weight
])
return params
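# usage sketch, assuming some external Muon optimizer class (not defined in this file):
# hand the 2d projection weights collected above to Muon, everything else to a standard optimizer
# muon_ids = set(map(id, model.muon_parameters()))
# rest = [p for p in model.parameters() if id(p) not in muon_ids]
# opts = [Muon(model.muon_parameters(), lr = 1e-3), torch.optim.AdamW(rest, lr = 3e-4)]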
def forward(
self,
x,
return_embed = False
):
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
batch, *spatial_dims, _, device = *x.shape, x.device
# Generate position coordinates
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
grid = torch.meshgrid(*grids, indexing = 'ij')
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
# flatten spatial dimensions for attention with nd rotary
pos = repeat(pos, '... p -> b (...) p', b = batch)
x, packed_shape = pack([x], 'b * d')
x = self.dropout(x)
embed = self.transformer(x, pos)
# return the embed with reconstituted patch shape
if return_embed:
embed, = unpack(embed, packed_shape, 'b * d')
return embed
# pooling to logits
pooled = reduce(embed, 'b n d -> b d', 'mean')
pooled = self.to_latent(pooled)
return self.mlp_head(pooled)
if __name__ == '__main__':
model = ViTND(
ndim = 5,
input_shape = (4, 8, 16, 32, 64),
patch_size = (2, 2, 4, 4, 8),
num_classes = 1000,
dim = 512,
depth = 6,
heads = 8,
mlp_dim = 2048,
channels = 3,
dropout = 0.1,
emb_dropout = 0.1
)
data = torch.randn(2, 3, 4, 8, 16, 32, 64)
logits = model(data)
embed = model(data, return_embed = True) # (2, 2, 4, 4, 8, 8, 512)
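# check the shapes stated above
assert logits.shape == (2, 1000)
assert embed.shape == (2, 2, 4, 4, 8, 8, 512)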

View File

@@ -0,0 +1,234 @@
# https://arxiv.org/abs/2510.14657
# but instead of their decorrelation module with its manual SGD update scheme, drop the extra projections and simply return a decorrelation auxiliary loss
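# concretely, for normed layer inputs X of shape (n, d), the loss below is
# sum over i != j of C_ij ** 2 / (d * (d - 1)), with C = X^T X / n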
import torch
from torch import nn, stack, tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# decorr loss
class DecorrelationLoss(Module):
def __init__(
self,
sample_frac = 1.,
soft_validate_num_sampled = False
):
super().__init__()
assert 0. <= sample_frac <= 1.
self.need_sample = sample_frac < 1.
self.sample_frac = sample_frac
self.soft_validate_num_sampled = soft_validate_num_sampled
self.register_buffer('zero', tensor(0.), persistent = False)
def forward(
self,
tokens
):
batch, seq_len, dim, device = *tokens.shape[-3:], tokens.device
if self.need_sample:
num_sampled = int(seq_len * self.sample_frac)
assert self.soft_validate_num_sampled or num_sampled >= 2.
if num_sampled <= 1:
return self.zero
# merge all leading dims, then sample a random subset of token positions per sequence
tokens, packed_shape = pack([tokens], '* n d')
indices = torch.randn(tokens.shape[:2], device = device).argsort(dim = -1)[..., :num_sampled]
batch_arange = torch.arange(tokens.shape[0], device = device)
batch_arange = rearrange(batch_arange, 'b -> b 1')
tokens = tokens[batch_arange, indices]
tokens, = unpack(tokens, packed_shape, '* n d')
# empirical feature covariance per layer input
dist = einsum(tokens, tokens, '... n d, ... n e -> ... d e') / tokens.shape[-2]
# (1. - eye) masks out the diagonal; dividing by (dim - 1) * dim averages over the off-diagonal pairs
eye = torch.eye(dim, device = device)
loss = dist.pow(2) * (1. - eye) / ((dim - 1) * dim)
loss = reduce(loss, '... b d e -> b', 'sum')
return loss.mean()
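# behavior sketch: duplicated channels are maximally penalized, independent ones are not
# loss_fn = DecorrelationLoss()
# correlated = repeat(torch.randn(2, 512, 1), 'b n 1 -> b n 256') # every channel identical
# loss_fn(correlated) # large
# loss_fn(torch.randn(2, 512, 256)) # near zero - off-diagonals vanish in expectation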
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
normed = self.norm(x)
# the net must consume the normed input - the layernorm was factored out of self.net so it can be returned for the decorr loss
return self.net(normed), normed
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.norm = nn.LayerNorm(dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
normed = self.norm(x)
qkv = self.to_qkv(normed).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out), normed
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
normed_inputs = []
for attn, ff in self.layers:
attn_out, attn_normed_inp = attn(x)
x = attn_out + x
ff_out, ff_normed_inp = ff(x)
x = ff_out + x
normed_inputs.append(attn_normed_inp)
normed_inputs.append(ff_normed_inp)
return self.norm(x), stack(normed_inputs)
class ViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., decorr_sample_frac = 1.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
# decorrelation loss related
self.has_decorr_loss = decorr_sample_frac > 0.
if self.has_decorr_loss:
self.decorr_loss = DecorrelationLoss(decorr_sample_frac)
self.register_buffer('zero', torch.tensor(0.), persistent = False)
def forward(
self,
img,
return_decorr_aux_loss = None
):
return_decorr_aux_loss = default(return_decorr_aux_loss, self.training) and self.has_decorr_loss
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x, normed_layer_inputs = self.transformer(x)
# maybe compute the decorrelation auxiliary loss
decorr_aux_loss = self.zero
if return_decorr_aux_loss:
decorr_aux_loss = self.decorr_loss(normed_layer_inputs)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x), decorr_aux_loss
# quick test
if __name__ == '__main__':
decorr_loss = DecorrelationLoss(0.1)
hiddens = torch.randn(6, 2, 512, 256)
decorr_loss(hiddens)
decorr_loss(hiddens[0])
decorr_loss = DecorrelationLoss(0.0001, soft_validate_num_sampled = True)
out = decorr_loss(hiddens)
assert out.item() == 0
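# end to end sketch of weighing the auxiliary loss into a training step
# (the 0.1 weight is an assumed hyperparameter, not from the paper)
v = ViT(image_size = 64, patch_size = 8, num_classes = 10, dim = 128, depth = 2, heads = 4, mlp_dim = 256, decorr_sample_frac = 0.5)
v.train()
logits, decorr_aux_loss = v(torch.randn(4, 3, 64, 64))
loss = F.cross_entropy(logits, torch.randint(0, 10, (4,))) + 0.1 * decorr_aux_loss
loss.backward()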