Compare commits

..

1 Commits

Author SHA1 Message Date
Phil Wang
771fb6daaf Nested navit (#325)
add a variant of NaViT using nested tensors
2024-08-20 15:07:20 -07:00
22 changed files with 88 additions and 2586 deletions

View File

@@ -18,9 +18,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v2
with:
python-version: '3.x'
- name: Install dependencies

View File

@@ -18,17 +18,17 @@ jobs:
python-version: [3.8, 3.9]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
python -m pip install -e .
python -m pip install pytest
python -m pip install wheel
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
pytest -q
python setup.py test

View File

@@ -198,7 +198,7 @@ preds = v(
) # (5, 1000)
```
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.5` and import as follows
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.4` and import as follows
```python
import torch
@@ -1218,8 +1218,7 @@ pred = cct(video)
<img src="./images/vivit.png" width="350px"></img>
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository includes the factorized encoder and the factorized self-attention variant.
The factorized encoder variant is a spatial transformer followed by a temporal one. The factorized self-attention variant is a spatio-temporal transformer with alternating spatial and temporal self-attention layers.
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository will offer the first variant, which is a spatial transformer followed by a temporal one.
```python
import torch
@@ -1235,8 +1234,7 @@ v = ViT(
spatial_depth = 6, # depth of the spatial transformer
temporal_depth = 6, # depth of the temporal transformer
heads = 8,
mlp_dim = 2048,
variant = 'factorized_encoder', # or 'factorized_self_attention'
mlp_dim = 2048
)
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
@@ -2133,72 +2131,4 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@inproceedings{Loshchilov2024nGPTNT,
title = {nGPT: Normalized Transformer with Representation Learning on the Hypersphere},
author = {Ilya Loshchilov and Cheng-Ping Hsieh and Simeng Sun and Boris Ginsburg},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273026160}
}
```
```bibtex
@inproceedings{Liu2017DeepHL,
title = {Deep Hyperspherical Learning},
author = {Weiyang Liu and Yanming Zhang and Xingguo Li and Zhen Liu and Bo Dai and Tuo Zhao and Le Song},
booktitle = {Neural Information Processing Systems},
year = {2017},
url = {https://api.semanticscholar.org/CorpusID:5104558}
}
```
```bibtex
@inproceedings{Zhou2024ValueRL,
title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```
```bibtex
@article{Zhu2024HyperConnections,
title = {Hyper-Connections},
author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
journal = {ArXiv},
year = {2024},
volume = {abs/2409.19606},
url = {https://api.semanticscholar.org/CorpusID:272987528}
}
```
```bibtex
@inproceedings{Fuller2025SimplerFV,
title = {Simpler Fast Vision Transformers with a Jumbo CLS Token},
author = {Anthony Fuller and Yousef Yassin and Daniel G. Kyrollos and Evan Shelhamer and James R. Green},
year = {2025},
url = {https://api.semanticscholar.org/CorpusID:276557720}
}
```
```bibtex
@misc{xiong2025ndrope,
author = {Jerry Xiong},
title = {On n-dimensional rotary positional embeddings},
year = {2025},
url = {https://jerryxio.ng/posts/nd-rope/}
}
```
```bibtex
@inproceedings{anonymous2025vat,
title = {{VAT}: Vision Action Transformer by Unlocking Full Representation of ViT},
author = {Anonymous},
booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
year = {2025},
url = {https://openreview.net/forum?id=TalHOvvLZu},
note = {under review}
}
```
*I visualise a time when we will be to robots what dogs are to humans, and Im rooting for the machines.* — Claude Shannon

View File

@@ -1,63 +0,0 @@
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.14.0"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" },
]
requires-python = ">=3.8"
keywords = [
"artificial intelligence",
"attention mechanism",
"image recognition",
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
dependencies = [
"einops>=0.7.0",
"torch>=1.10",
"torchvision",
]
[project.optional-dependencies]
test = [
"pytest",
"torch==2.4.0",
"torchvision==0.19.0",
]
[project.urls]
Homepage = "https://github.com/lucidrains/vit-pytorch"
Repository = "https://github.com/lucidrains/vit-pytorch"
[tool.setuptools]
include-package-data = true
[tool.setuptools.packages.find]
include = ["vit_pytorch*"]
exclude = ["examples*", "tests*", "test*"]
[tool.pytest.ini_options]
testpaths = ["tests", "."]
python_files = ["test_*.py", "*_test.py"]
addopts = "-q"
filterwarnings = [
"ignore::FutureWarning",
]

42
setup.py Normal file
View File

@@ -0,0 +1,42 @@
from setuptools import setup, find_packages
with open('README.md') as f:
long_description = f.read()
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.7.6',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
long_description_content_type = 'text/markdown',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
url = 'https://github.com/lucidrains/vit-pytorch',
keywords = [
'artificial intelligence',
'attention mechanism',
'image recognition'
],
install_requires=[
'einops>=0.7.0',
'torch>=1.10',
'torchvision'
],
setup_requires=[
'pytest-runner',
],
tests_require=[
'pytest',
'torch==1.12.1',
'torchvision==0.13.1'
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)

BIN
tests/.DS_Store vendored

Binary file not shown.

View File

@@ -1,7 +1,7 @@
import torch
from vit_pytorch import ViT
def test_vit():
def test():
v = ViT(
image_size = 256,
patch_size = 32,

View File

@@ -1,161 +0,0 @@
from contextlib import nullcontext
import torch
from torch import is_tensor, randn
from torch.nn import Module, Linear, Parameter
from torch.utils._pytree import tree_flatten, tree_unflatten
from einops import rearrange, repeat
# helper functions
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
# classes
class AcceptVideoWrapper(Module):
def __init__(
self,
image_net: Module,
forward_function = 'forward',
add_time_pos_emb = False,
dim_emb = None,
time_seq_len = None,
embed_is_channel_first = False,
output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
proj_embed_to_dim = None
):
super().__init__()
self.image_net = image_net
self.forward_function = forward_function # for openclip, used in TRI-LBM
self.add_time_pos_emb = add_time_pos_emb
self.output_pos_add_pos_emb = output_pos_add_pos_emb
# maybe project the image embedding
self.embed_proj = None
if exists(proj_embed_to_dim):
assert exists(dim_emb), '`dim_emb` must be passed in'
self.embed_proj = Linear(dim_emb, proj_embed_to_dim)
# time positional embedding
if add_time_pos_emb:
assert exists(dim_emb) and exists(time_seq_len), '`dim_emb` and `time_seq_len` must be set if adding positional embeddings to the output'
self.time_seq_len = time_seq_len
dim_pos_emb = default(proj_embed_to_dim, dim_emb)
self.pos_emb = Parameter(randn(time_seq_len, dim_pos_emb) * 1e-2)
self.embed_is_channel_first = embed_is_channel_first
def forward(
self,
video, # (b c t h w)
eval_with_no_grad = False,
forward_kwargs = dict()
):
add_time_pos_emb = self.add_time_pos_emb
time = video.shape[2]
# maybe validate time positional embedding
if add_time_pos_emb:
assert time <= self.time_seq_len, f'received video with {time} frames but `time_seq_len` ({self.time_seq_len}) is too low'
video = rearrange(video, 'b c t h w -> b t c h w')
video = rearrange(video, 'b t ... -> (b t) ...')
# forward through image net for outputs
func = getattr(self.image_net, self.forward_function)
if eval_with_no_grad:
self.image_net.eval()
context = torch.no_grad if eval_with_no_grad else nullcontext
with context():
outputs = func(video, **forward_kwargs)
# handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
outputs, tree_spec = tree_flatten(outputs)
outputs = tuple(rearrange(t, '(b t) ... -> b t ...', t = time) if is_tensor(t) and t.numel() > 1 else t for t in outputs)
# maybe project embedding
if exists(self.embed_proj):
outputs = list(outputs)
embed = outputs[self.output_pos_add_pos_emb]
outputs[self.output_pos_add_pos_emb] = self.embed_proj(embed)
# maybe add time positional embedding
if add_time_pos_emb:
outputs = list(outputs)
embed = outputs[self.output_pos_add_pos_emb]
pos_emb = rearrange(self.pos_emb, 't d -> 1 t d')
# handle the network outputting embeddings with spatial dimensions intact - assume embedded dimension is last
dims_to_unsqueeze = embed.ndim - pos_emb.ndim
one_dims = ((1,) * dims_to_unsqueeze)
if self.embed_is_channel_first:
pos_emb = pos_emb.reshape(*pos_emb.shape, *one_dims)
else:
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *one_dims, pos_emb.shape[-1])
pos_emb = pos_emb[:, :embed.shape[1]]
embed = embed + pos_emb
outputs[self.output_pos_add_pos_emb] = embed
return tree_unflatten(outputs, tree_spec)
# main
if __name__ == '__main__':
from vit_pytorch import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
videos = torch.randn(1, 3, 7, 256, 256)
# step up the difficulty and return embeddings for robotics
from vit_pytorch.extractor import Extractor
v = Extractor(v)
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024, proj_embed_to_dim = 512)
logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
assert logits.shape == (1, 7, 1000)
assert embeddings.shape == (1, 7, 65, 512)

View File

@@ -316,9 +316,6 @@ class CCT(nn.Module):
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
dropout_rate=0.,
attention_dropout=0.1,
stochastic_depth_rate=0.1,
*args, **kwargs
):
super().__init__()
@@ -343,9 +340,9 @@ class CCT(nn.Module):
width=img_width),
embedding_dim=embedding_dim,
seq_pool=True,
dropout_rate=dropout_rate,
attention_dropout=attention_dropout,
stochastic_depth_rate=stochastic_depth_rate,
dropout_rate=0.,
attention_dropout=0.1,
stochastic_depth=0.1,
*args, **kwargs)
def forward(self, x):

View File

@@ -167,10 +167,8 @@ class Tokenizer(nn.Module):
stride,
padding,
frame_stride=1,
frame_padding=None,
frame_pooling_stride=1,
frame_pooling_kernel_size=1,
frame_pooling_padding=None,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
@@ -190,22 +188,16 @@ class Tokenizer(nn.Module):
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
if frame_padding is None:
frame_padding = frame_kernel_size // 2
if frame_pooling_padding is None:
frame_pooling_padding = frame_pooling_kernel_size // 2
self.conv_layers = nn.Sequential(
*[nn.Sequential(
nn.Conv3d(chan_in, chan_out,
kernel_size=(frame_kernel_size, kernel_size, kernel_size),
stride=(frame_stride, stride, stride),
padding=(frame_padding, padding, padding), bias=conv_bias),
padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias),
nn.Identity() if not exists(activation) else activation(),
nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
stride=(frame_pooling_stride, pooling_stride, pooling_stride),
padding=(frame_pooling_padding, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
)
for chan_in, chan_out in n_filter_list_pairs
])
@@ -332,10 +324,8 @@ class CCT(nn.Module):
n_conv_layers=1,
frame_stride=1,
frame_kernel_size=3,
frame_padding=None,
frame_pooling_kernel_size=1,
frame_pooling_stride=1,
frame_pooling_padding=None,
kernel_size=7,
stride=2,
padding=3,
@@ -352,10 +342,8 @@ class CCT(nn.Module):
n_output_channels=embedding_dim,
frame_stride=frame_stride,
frame_kernel_size=frame_kernel_size,
frame_padding=frame_padding,
frame_pooling_stride=frame_pooling_stride,
frame_pooling_kernel_size=frame_pooling_kernel_size,
frame_pooling_padding=frame_pooling_padding,
kernel_size=kernel_size,
stride=stride,
padding=padding,

View File

@@ -1,204 +0,0 @@
# Simpler Fast Vision Transformers with a Jumbo CLS Token
# https://arxiv.org/abs/2502.15021
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def divisible_by(num, den):
return (num % den) == 0
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = temperature ** -omega
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pos_emb = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pos_emb.type(dtype)
# classes
def FeedForward(dim, mult = 4.):
hidden_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class JumboViT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
num_jumbo_cls = 1, # differing from paper, allow for multiple jumbo cls, so one could break it up into 2 jumbo cls tokens with 3x the dim, as an example
jumbo_cls_k = 6, # they use a CLS token with this factor times the dimension - 6 was the value they settled on
jumbo_ff_mult = 2, # expansion factor of the jumbo cls token feedforward
channels = 3,
dim_head = 64
):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
jumbo_cls_dim = dim * jumbo_cls_k
self.jumbo_cls_token = nn.Parameter(torch.zeros(num_jumbo_cls, jumbo_cls_dim))
jumbo_cls_to_tokens = Rearrange('b n (k d) -> b (n k) d', k = jumbo_cls_k)
self.jumbo_cls_to_tokens = jumbo_cls_to_tokens
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
# attention and feedforwards
self.jumbo_ff = nn.Sequential(
Rearrange('b (n k) d -> b n (k d)', k = jumbo_cls_k),
FeedForward(jumbo_cls_dim, int(jumbo_cls_dim * jumbo_ff_mult)), # they use separate parameters for the jumbo feedforward, weight tied for parameter efficient
jumbo_cls_to_tokens
)
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim),
]))
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
batch, device = img.shape[0], img.device
x = self.to_patch_embedding(img)
# pos embedding
pos_emb = self.pos_embedding.to(device, dtype = x.dtype)
x = x + pos_emb
# add cls tokens
cls_tokens = repeat(self.jumbo_cls_token, 'nj d -> b nj d', b = batch)
jumbo_tokens = self.jumbo_cls_to_tokens(cls_tokens)
x, cls_packed_shape = pack([jumbo_tokens, x], 'b * d')
# attention and feedforwards
for layer, (attn, ff) in enumerate(self.layers, start = 1):
is_last = layer == len(self.layers)
x = attn(x) + x
# jumbo feedforward
jumbo_cls_tokens, x = unpack(x, cls_packed_shape, 'b * d')
x = ff(x) + x
jumbo_cls_tokens = self.jumbo_ff(jumbo_cls_tokens) + jumbo_cls_tokens
if is_last:
continue
x, _ = pack([jumbo_cls_tokens, x], 'b * d')
pooled = reduce(jumbo_cls_tokens, 'b n d -> b d', 'mean')
# normalization and project to logits
embed = self.norm(pooled)
embed = self.to_latent(embed)
logits = self.linear_head(embed)
return logits
# copy pasteable file
if __name__ == '__main__':
v = JumboViT(
num_classes = 1000,
image_size = 64,
patch_size = 8,
dim = 16,
depth = 2,
heads = 2,
mlp_dim = 32,
jumbo_cls_k = 3,
jumbo_ff_mult = 2,
)
images = torch.randn(1, 3, 64, 64)
logits = v(images)
assert logits.shape == (1, 1000)

View File

@@ -5,6 +5,7 @@ from functools import partial
import torch
import packaging.version as pkg_version
assert pkg_version.parse(torch.__version__) >= pkg_version.parse('2.4'), 'install pytorch 2.4 or greater to use this flavor of NaViT'
from torch import nn, Tensor
import torch.nn.functional as F
@@ -41,7 +42,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim, bias = False)
@@ -56,8 +57,8 @@ class Attention(Module):
# in the paper, they employ qk rmsnorm, a way to stabilize attention
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
self.query_norm = nn.LayerNorm(dim_head, bias = False)
self.key_norm = nn.LayerNorm(dim_head, bias = False)
self.dropout = dropout
@@ -111,13 +112,13 @@ class Attention(Module):
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
@@ -146,15 +147,9 @@ class NaViT(Module):
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
qk_rmsnorm = True,
token_dropout_prob: float | None = None
):
super().__init__()
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
print('nested tensor NaViT was tested on pytorch 2.5')
image_height, image_width = pair(image_size)
# what percent of tokens to dropout
@@ -185,7 +180,7 @@ class NaViT(Module):
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
# final attention pooling queries
@@ -326,5 +321,3 @@ if __name__ == '__main__':
]
assert v(images).shape == (5, 1000)
v(images).sum().backward()

View File

@@ -1,356 +0,0 @@
from __future__ import annotations
from typing import List
from functools import partial
import torch
import packaging.version as pkg_version
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.nested import nested_tensor
from einops import rearrange
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def divisible_by(numer, denom):
return (numer % denom) == 0
# feedforward
def FeedForward(dim, hidden_dim, dropout = 0.):
return nn.Sequential(
nn.LayerNorm(dim, bias = False),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
super().__init__()
self.norm = nn.LayerNorm(dim, bias = False)
dim_inner = heads * dim_head
self.heads = heads
self.dim_head = dim_head
self.to_queries = nn.Linear(dim, dim_inner, bias = False)
self.to_keys = nn.Linear(dim, dim_inner, bias = False)
self.to_values = nn.Linear(dim, dim_inner, bias = False)
# in the paper, they employ qk rmsnorm, a way to stabilize attention
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
self.dropout = dropout
self.to_out = nn.Linear(dim_inner, dim, bias = False)
def forward(
self,
x,
context: Tensor | None = None
):
x = self.norm(x)
# for attention pooling, one query pooling to entire sequence
context = default(context, x)
# queries, keys, values
query = self.to_queries(x)
key = self.to_keys(context)
value = self.to_values(context)
# split heads
def split_heads(t):
return t.unflatten(-1, (self.heads, self.dim_head))
def transpose_head_seq(t):
return t.transpose(1, 2)
query, key, value = map(split_heads, (query, key, value))
# qk norm for attention stability
query = self.query_norm(query)
key = self.key_norm(key)
query, key, value = map(transpose_head_seq, (query, key, value))
# attention
out = F.scaled_dot_product_attention(
query, key, value,
dropout_p = self.dropout if self.training else 0.
)
# merge heads
out = out.transpose(1, 2).flatten(-2)
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
self.norm = nn.LayerNorm(dim, bias = False)
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class NaViT(Module):
def __init__(
self,
*,
image_size,
max_frames,
patch_size,
frame_patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
num_registers = 4,
qk_rmsnorm = True,
token_dropout_prob: float | None = None
):
super().__init__()
image_height, image_width = pair(image_size)
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
print('nested tensor NaViT was tested on pytorch 2.5')
# what percent of tokens to dropout
# if int or float given, then assume constant dropout prob
# otherwise accept a callback that in turn calculates dropout prob from height and width
self.token_dropout_prob = token_dropout_prob
# calculate patching related stuff
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
assert divisible_by(max_frames, frame_patch_size)
patch_frame_dim, patch_height_dim, patch_width_dim = (max_frames // frame_patch_size), (image_height // patch_size), (image_width // patch_size)
patch_dim = channels * (patch_size ** 2) * frame_patch_size
self.channels = channels
self.patch_size = patch_size
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
self.to_patch_embedding = nn.Sequential(
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embed_frame = nn.Parameter(torch.zeros(patch_frame_dim, dim))
self.pos_embed_height = nn.Parameter(torch.zeros(patch_height_dim, dim))
self.pos_embed_width = nn.Parameter(torch.zeros(patch_width_dim, dim))
# register tokens
self.register_tokens = nn.Parameter(torch.zeros(num_registers, dim))
nn.init.normal_(self.pos_embed_frame, std = 0.02)
nn.init.normal_(self.pos_embed_height, std = 0.02)
nn.init.normal_(self.pos_embed_width, std = 0.02)
nn.init.normal_(self.register_tokens, std = 0.02)
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
# final attention pooling queries
self.attn_pool_queries = nn.Parameter(torch.randn(dim))
self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)
# output to logits
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim, bias = False),
nn.Linear(dim, num_classes, bias = False)
)
@property
def device(self):
return next(self.parameters()).device
def forward(
self,
volumes: List[Tensor], # different resolution images / CT scans
):
batch, device = len(volumes), self.device
arange = partial(torch.arange, device = device)
assert all([volume.ndim == 4 and volume.shape[0] == self.channels for volume in volumes]), f'all volumes must have {self.channels} channels and number of dimensions of {self.channels} (channels, frame, height, width)'
all_patches = [self.to_patches(volume) for volume in volumes]
# prepare factorized positional embedding height width indices
positions = []
for patches in all_patches:
patch_frame, patch_height, patch_width = patches.shape[:3]
fhw_indices = torch.stack(torch.meshgrid((arange(patch_frame), arange(patch_height), arange(patch_width)), indexing = 'ij'), dim = -1)
fhw_indices = rearrange(fhw_indices, 'f h w c -> (f h w) c')
positions.append(fhw_indices)
# need the sizes to compute token dropout + positional embedding
tokens = [rearrange(patches, 'f h w d -> (f h w) d') for patches in all_patches]
# handle token dropout
seq_lens = torch.tensor([i.shape[0] for i in tokens], device = device)
if self.training and self.token_dropout_prob > 0:
keep_seq_lens = ((1. - self.token_dropout_prob) * seq_lens).int().clamp(min = 1)
kept_tokens = []
kept_positions = []
for one_image_tokens, one_image_positions, seq_len, num_keep in zip(tokens, positions, seq_lens, keep_seq_lens):
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
one_image_kept_tokens = one_image_tokens[keep_indices]
one_image_kept_positions = one_image_positions[keep_indices]
kept_tokens.append(one_image_kept_tokens)
kept_positions.append(one_image_kept_positions)
tokens, positions, seq_lens = kept_tokens, kept_positions, keep_seq_lens
# add all height and width factorized positions
frame_indices, height_indices, width_indices = torch.cat(positions).unbind(dim = -1)
frame_embed, height_embed, width_embed = self.pos_embed_frame[frame_indices], self.pos_embed_height[height_indices], self.pos_embed_width[width_indices]
pos_embed = frame_embed + height_embed + width_embed
tokens = torch.cat(tokens)
# linear projection to patch embeddings
tokens = self.to_patch_embedding(tokens)
# absolute positions
tokens = tokens + pos_embed
# add register tokens
tokens = tokens.split(seq_lens.tolist())
tokens = [torch.cat((self.register_tokens, one_tokens)) for one_tokens in tokens]
# use nested tensor for transformers and save on padding computation
tokens = nested_tensor(tokens, layout = torch.jagged, device = device)
# embedding dropout
tokens = self.dropout(tokens)
# transformer
tokens = self.transformer(tokens)
# attention pooling
# will use a jagged tensor for queries, as SDPA requires all inputs to be jagged, or not
attn_pool_queries = [rearrange(self.attn_pool_queries, '... -> 1 ...')] * batch
attn_pool_queries = nested_tensor(attn_pool_queries, layout = torch.jagged)
pooled = self.attn_pool(attn_pool_queries, tokens)
# back to unjagged
logits = torch.stack(pooled.unbind())
logits = rearrange(logits, 'b 1 d -> b d')
logits = self.to_latent(logits)
return self.mlp_head(logits)
# quick test
if __name__ == '__main__':
# works for torch 2.5
v = NaViT(
image_size = 256,
max_frames = 8,
patch_size = 32,
frame_patch_size = 2,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.,
emb_dropout = 0.,
token_dropout_prob = 0.1
)
# 5 volumetric data (videos or CT scans) of different resolutions - List[Tensor]
volumes = [
torch.randn(3, 2, 256, 256), torch.randn(3, 8, 128, 128),
torch.randn(3, 4, 128, 256), torch.randn(3, 2, 256, 128),
torch.randn(3, 4, 64, 256)
]
assert v(volumes).shape == (5, 1000)
v(volumes).sum().backward()

View File

@@ -1,264 +0,0 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
# functions
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def divisible_by(numer, denom):
return (numer % denom) == 0
def l2norm(t, dim = -1):
return F.normalize(t, dim = dim, p = 2)
# for use with parametrize
class L2Norm(Module):
def __init__(self, dim = -1):
super().__init__()
self.dim = dim
def forward(self, t):
return l2norm(t, dim = self.dim)
class NormLinear(Module):
def __init__(
self,
dim,
dim_out,
norm_dim_in = True
):
super().__init__()
self.linear = nn.Linear(dim, dim_out, bias = False)
parametrize.register_parametrization(
self.linear,
'weight',
L2Norm(dim = -1 if norm_dim_in else 0)
)
@property
def weight(self):
return self.linear.weight
def forward(self, x):
return self.linear(x)
# attention and feedforward
class Attention(Module):
def __init__(
self,
dim,
*,
dim_head = 64,
heads = 8,
dropout = 0.
):
super().__init__()
dim_inner = dim_head * heads
self.to_q = NormLinear(dim, dim_inner)
self.to_k = NormLinear(dim, dim_inner)
self.to_v = NormLinear(dim, dim_inner)
self.dropout = dropout
self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
self.merge_heads = Rearrange('b h n d -> b n (h d)')
self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
def forward(
self,
x
):
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
q, k, v = map(self.split_heads, (q, k, v))
# query key rmsnorm
q, k = map(l2norm, (q, k))
q = q * self.q_scale
k = k * self.k_scale
# scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
out = F.scaled_dot_product_attention(
q, k, v,
dropout_p = self.dropout if self.training else 0.,
scale = 1.
)
out = self.merge_heads(out)
return self.to_out(out)
class FeedForward(Module):
def __init__(
self,
dim,
*,
dim_inner,
dropout = 0.
):
super().__init__()
dim_inner = int(dim_inner * 2 / 3)
self.dim = dim
self.dropout = nn.Dropout(dropout)
self.to_hidden = NormLinear(dim, dim_inner)
self.to_gate = NormLinear(dim, dim_inner)
self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
self.gate_scale = nn.Parameter(torch.ones(dim_inner))
self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
def forward(self, x):
hidden, gate = self.to_hidden(x), self.to_gate(x)
hidden = hidden * self.hidden_scale
gate = gate * self.gate_scale * (self.dim ** 0.5)
hidden = F.silu(gate) * hidden
hidden = self.dropout(hidden)
return self.to_out(hidden)
# classes
class nViT(Module):
""" https://arxiv.org/abs/2410.01131 """
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
dropout = 0.,
channels = 3,
dim_head = 64,
residual_lerp_scale_init = None
):
super().__init__()
image_height, image_width = pair(image_size)
# calculate patching related stuff
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
patch_dim = channels * (patch_size ** 2)
num_patches = patch_height_dim * patch_width_dim
self.channels = channels
self.patch_size = patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
NormLinear(patch_dim, dim, norm_dim_in = False),
)
self.abs_pos_emb = NormLinear(dim, num_patches)
residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
# layers
self.dim = dim
self.scale = dim ** 0.5
self.layers = ModuleList([])
self.residual_lerp_scales = nn.ParameterList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, dim_head = dim_head, heads = heads, dropout = dropout),
FeedForward(dim, dim_inner = mlp_dim, dropout = dropout),
]))
self.residual_lerp_scales.append(nn.ParameterList([
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
]))
self.logit_scale = nn.Parameter(torch.ones(num_classes))
self.to_pred = NormLinear(dim, num_classes)
@torch.no_grad()
def norm_weights_(self):
for module in self.modules():
if not isinstance(module, NormLinear):
continue
normed = module.weight
original = module.linear.parametrizations.weight.original
original.copy_(normed)
def forward(self, images):
device = images.device
tokens = self.to_patch_embedding(images)
seq_len = tokens.shape[-2]
pos_emb = self.abs_pos_emb.weight[torch.arange(seq_len, device = device)]
tokens = l2norm(tokens + pos_emb)
for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):
attn_out = l2norm(attn(tokens))
tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale))
ff_out = l2norm(ff(tokens))
tokens = l2norm(tokens.lerp(ff_out, ff_alpha * self.scale))
pooled = reduce(tokens, 'b n d -> b d', 'mean')
logits = self.to_pred(pooled)
logits = logits * self.logit_scale * self.scale
return logits
# quick test
if __name__ == '__main__':
v = nViT(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
)
img = torch.randn(4, 3, 256, 256)
logits = v(img) # (4, 1000)
assert logits.shape == (4, 1000)

View File

@@ -20,18 +20,6 @@ def divisible_by(val, d):
# helper classes
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class Downsample(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
@@ -224,10 +212,10 @@ class RegionViT(nn.Module):
if tokenize_local_3_conv:
self.local_encoder = nn.Sequential(
nn.Conv2d(3, init_dim, 3, 2, 1),
ChanLayerNorm(init_dim),
nn.LayerNorm(init_dim),
nn.GELU(),
nn.Conv2d(init_dim, init_dim, 3, 2, 1),
ChanLayerNorm(init_dim),
nn.LayerNorm(init_dim),
nn.GELU(),
nn.Conv2d(init_dim, init_dim, 3, 1, 1)
)

View File

@@ -3,14 +3,14 @@ from math import sqrt, pi, log
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.amp import autocast
from torch.cuda.amp import autocast
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# rotary embeddings
@autocast('cuda', enabled = False)
@autocast(enabled = False)
def rotate_every_two(x):
x = rearrange(x, '... (d j) -> ... d j', j = 2)
x1, x2 = x.unbind(dim = -1)
@@ -24,7 +24,7 @@ class AxialRotaryEmbedding(nn.Module):
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
self.register_buffer('scales', scales)
@autocast('cuda', enabled = False)
@autocast(enabled = False)
def forward(self, x):
device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))

View File

@@ -1,233 +0,0 @@
"""
ViT + Hyper-Connections + Register Tokens
https://arxiv.org/abs/2409.19606
"""
import torch
from torch import nn, tensor
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange
# b - batch, h - heads, n - sequence, e - expansion rate / residual streams, d - feature dimension
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# hyper connections
class HyperConnection(Module):
def __init__(
self,
dim,
num_residual_streams,
layer_index
):
""" Appendix J - Algorithm 2, Dynamic only """
super().__init__()
self.norm = nn.LayerNorm(dim, bias = False)
self.num_residual_streams = num_residual_streams
self.layer_index = layer_index
self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
init_alpha0 = torch.zeros((num_residual_streams, 1))
init_alpha0[layer_index % num_residual_streams, 0] = 1.
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
self.dynamic_alpha_scale = nn.Parameter(tensor(1e-2))
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
self.dynamic_beta_scale = nn.Parameter(tensor(1e-2))
def width_connection(self, residuals):
normed = self.norm(residuals)
wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
alpha = dynamic_alpha + self.static_alpha
dc_weight = (normed @ self.dynamic_beta_fn).tanh()
dynamic_beta = dc_weight * self.dynamic_beta_scale
beta = dynamic_beta + self.static_beta
# width connection
mix_h = einsum(alpha, residuals, '... e1 e2, ... e1 d -> ... e2 d')
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
return branch_input, residuals, beta
def depth_connection(
self,
branch_output,
residuals,
beta
):
return einsum(branch_output, beta, "b n d, b n e -> b n e d") + residuals
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_residual_streams):
super().__init__()
self.num_residual_streams = num_residual_streams
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for layer_index in range(depth):
self.layers.append(nn.ModuleList([
HyperConnection(dim, num_residual_streams, layer_index),
Attention(dim, heads = heads, dim_head = dim_head),
HyperConnection(dim, num_residual_streams, layer_index),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
x = repeat(x, 'b n d -> b n e d', e = self.num_residual_streams)
for attn_hyper_conn, attn, ff_hyper_conn, ff in self.layers:
x, attn_res, beta = attn_hyper_conn.width_connection(x)
x = attn(x)
x = attn_hyper_conn.depth_connection(x, attn_res, beta)
x, ff_res, beta = ff_hyper_conn.width_connection(x)
x = ff(x)
x = ff_hyper_conn.depth_connection(x, ff_res, beta)
x = reduce(x, 'b n e d -> b n d', 'sum')
return self.norm(x)
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_residual_streams, num_register_tokens = 4, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
batch, device = img.shape[0], img.device
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(x)
r = repeat(self.register_tokens, 'n d -> b n d', b = batch)
x, ps = pack([x, r], 'b * d')
x = self.transformer(x)
x, _ = unpack(x, ps, 'b * d')
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
# main
if __name__ == '__main__':
vit = SimpleViT(
num_classes = 1000,
image_size = 256,
patch_size = 8,
dim = 1024,
depth = 12,
heads = 8,
mlp_dim = 2048,
num_residual_streams = 8
)
images = torch.randn(3, 3, 256, 256)
logits = vit(images)

View File

@@ -1,159 +0,0 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# classes
def FeedForward(dim, hidden_dim):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
self.to_residual_mix = nn.Sequential(
nn.Linear(dim, heads),
nn.Sigmoid(),
Rearrange('b n h -> b h n 1')
) if learned_value_residual_mix else (lambda _: 0.5)
def forward(self, x, value_residual = None):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
if exists(value_residual):
mix = self.to_residual_mix(x)
v = v * mix + value_residual * (1. - mix)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out), v
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for i in range(depth):
is_first = i == 0
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
value_residual = None
for attn, ff in self.layers:
attn_out, values = attn(x, value_residual = value_residual)
value_residual = default(value_residual, values)
x = attn_out + x
x = ff(x) + x
return self.norm(x)
class SimpleViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
device = img.device
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype=x.dtype)
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
# quick test
if __name__ == '__main__':
v = SimpleViT(
num_classes = 1000,
image_size = 256,
patch_size = 8,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
)
images = torch.randn(2, 3, 256, 256)
logits = v(images)

View File

@@ -1,444 +0,0 @@
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import nn, cat, stack, tensor
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class FeedForward(Module):
def __init__(
self,
dim,
hidden_dim,
dropout = 0.
):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(
self,
dim,
heads = 8,
dim_head = 64,
dropout = 0.,
cross_attend = False
):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.cross_attend = cross_attend
self.context_norm = nn.LayerNorm(dim) if cross_attend else None
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, context = None):
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross attending, or vice versa'
x = self.norm(x)
# handle norming of context for cross attention
kv_input = x
if self.cross_attend:
context = self.context_norm(context)
kv_input = context
# project for queries, keys, values
qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(
self,
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout = 0.
):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(
self,
x,
return_hiddens = False
):
hiddens = []
for attn, ff in self.layers:
hiddens.append(x)
x = attn(x) + x
x = ff(x) + x
x = self.norm(x)
if not return_hiddens:
return x
return x, hiddens
class ViT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
pool = 'cls',
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.
):
super().__init__()
self.dim = dim
self.depth = depth
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, img, return_hiddens = False):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x, hiddens = self.transformer(x, return_hiddens = True)
# return the representation trajectory
if return_hiddens:
return x, stack(hiddens)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
# proposed VAT
# https://openreview.net/forum?id=TalHOvvLZu
# simple way to get SOTA on Libero dataset (beating fine-tuned pi-zero)
class VAT(Module):
def __init__(
self,
vit: ViT | dict,
*,
dim,
depth,
heads,
dim_head,
dim_action,
mlp_dim,
num_views = None,
dim_extra_token = None,
action_chunk_len = 7,
time_seq_len = 1,
dropout = 0.,
add_self_attn = True, # in the paper, they didn't have any ways for the action token to exchange information with the extra token, so we'll just add it as an option
self_attn_heads = 4,
self_attn_dim_head = 32,
vit_layer_indices: tuple[int, ...] | None = None
):
super().__init__()
if isinstance(vit, dict):
vit = ViT(**vit)
self.vit = vit
vit_dim = vit.dim
assert vit.depth == depth or exists(vit_layer_indices), f'if the VAT depth is not equal to the ViT depth, you must pass in the indices from the ViT to be layered to the VAT in order from bottom to top'
vit_layer_indices = default(vit_layer_indices, tuple(range(depth)))
assert len(vit_layer_indices) == depth, f'number of vit layer indices {len(vit_layer_indices)} does not much the VAT depth {depth}'
self.register_buffer('layer_indices', tensor(vit_layer_indices), persistent = False)
# handle maybe multiple frames
is_video = time_seq_len > 1
self.is_video = is_video
self.time_seq_len = time_seq_len
self.time_pos_emb = nn.Parameter(torch.randn(time_seq_len, vit_dim) * 1e-2) if is_video else None
# maybe view embeddings
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
# to action tokens
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
self.layers = ModuleList([])
for _ in range(depth):
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
self.layers.append(ModuleList([
maybe_self_attn,
Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
]))
self.final_norm = nn.LayerNorm(dim)
self.to_pred_action = nn.Linear(dim, dim_action, bias = False)
# handle the extra token
self.accept_extra_token = exists(dim_extra_token)
if exists(dim_extra_token):
self.to_extra_token = nn.Linear(dim_extra_token, dim)
def forward(
self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
extra = None, # (b d) - batch, dim extra
actions = None, # (b k d) - batch, action chunk length, action dimension
):
batch = video_or_image.shape[0]
return_loss = exists(actions)
# handle some various input dimensions
if video_or_image.ndim == 4:
video_or_image = rearrange(video_or_image, 'b 1 c h w')
assert (
(video_or_image.ndim == 5 and not self.is_video) or
(video_or_image.ndim == 6 and self.is_video)
)
if video_or_image.ndim == 5:
video_or_image = rearrange(video_or_image, 'b v c h w -> b v c 1 h w')
assert video_or_image.shape[3] == self.time_seq_len
# to images
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
images, packed_shape = pack([images], '* c h w')
# get representation trajectory from vit
embed, hiddens = self.vit(images, return_hiddens = True)
hiddens = cat((hiddens, embed[None, ...]))
# extract the hiddens needed for the action cross attention
hiddens = hiddens[self.layer_indices]
# pack temporarily for embedding
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers
# maybe add time embeddings
if self.is_video:
time_pos_emb = rearrange(time_pos_emb, 't d -> t 1 d')
hiddens = hiddens + time_pos_emb
# maybe view embeddings
if exists(self.view_emb):
assert self.view_emb.shape[0] == hiddens.shape[2]
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
hiddens = hiddens + view_emb
# cross from actions to representation trajectory
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
# get main action tokens and maybe append extra
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
has_extra = exists(extra)
if has_extra:
assert self.accept_extra_token
extra_token = self.to_extra_token(extra)
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
# cross attention
for (maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
if exists(maybe_self_attn):
action_tokens = maybe_self_attn(action_tokens) + action_tokens
action_tokens = ff(action_tokens) + action_tokens
# maybe unpack extra
if has_extra:
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
# norm and prediction
action_tokens = self.final_norm(action_tokens)
pred_action = self.to_pred_action(action_tokens)
if not return_loss:
return pred_action
assert pred_action.shape[1] == actions.shape[1]
# they found l1 loss suffices
return F.l1_loss(pred_action, actions)
# quick test
if __name__ == '__main__':
vit = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 512,
heads = 8,
depth = 4,
mlp_dim = 2048
)
vat = VAT(
vit,
dim = 512,
depth = 9,
heads = 8,
dim_head = 64,
mlp_dim = 2048,
dim_action = 20,
action_chunk_len = 7,
time_seq_len = 1,
num_views = 2,
add_self_attn = True,
dim_extra_token = 33, # extra token with some variable dimension
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
0, 0, 1, 1, 2, 2, 3, 3, 4
)
)
images = torch.randn(2, 2, 3, 256, 256) # (2 views with 4 frames)
extra = torch.randn(2, 33) # extra internal state
actions = torch.randn(2, 7, 20) # actions for learning
loss = vat(images, actions = actions, extra = extra)
loss.backward()
# after much training
pred_actions = vat(images)
assert pred_actions.shape == (2, 7, 20)

View File

@@ -1,191 +0,0 @@
from __future__ import annotations
import torch
from torch import nn
from torch.nn import Module
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def join(arr, delimiter = ' '):
return delimiter.join(arr)
def ensure_tuple(t, length):
if isinstance(t, (tuple, list)):
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
return tuple(t)
return (t,) * length
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class ViTND(Module):
def __init__(
self,
*,
ndim: int,
input_shape: int | tuple[int, ...],
patch_size: int | tuple[int, ...],
num_classes: int,
dim: int,
depth: int,
heads: int,
mlp_dim: int,
pool: str = 'cls',
channels: int = 3,
dim_head: int = 64,
dropout: float = 0.,
emb_dropout: float = 0.
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.ndim = ndim
self.pool = pool
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b ({join(dim_names)}) ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, x):
x = self.to_patch_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x[:, 1:].mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
if __name__ == '__main__':
model = ViTND(
ndim = 4,
input_shape = (8, 16, 32, 64),
patch_size = (2, 4, 4, 8),
num_classes = 1000,
dim = 512,
depth = 6,
heads = 8,
mlp_dim = 2048,
channels = 3,
dropout = 0.1,
emb_dropout = 0.1
)
occupancy_time = torch.randn(2, 3, 8, 16, 32, 64)
logits = model(occupancy_time)

View File

@@ -1,325 +0,0 @@
from __future__ import annotations
import torch
from torch import nn, arange, cat, stack, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
def join(arr, delimiter = ' '):
return delimiter.join(arr)
def ensure_tuple(t, length):
if isinstance(t, (tuple, list)):
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
return tuple(t)
return (t,) * length
# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/
def _phi(m: int) -> float:
x = 2.0
for _ in range(10):
x = (1 + x) ** (1.0 / (m + 1.0))
return x
def make_directions(n: int, d: int) -> Tensor:
g = _phi(d)
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
z = torch.fmod(i * alpha, 1.0)
directions = torch.erfinv(2.0 * z - 1.0)
directions = l2norm(directions)
return directions.float()
class GoldenGateRoPENd(Module):
def __init__(
self,
dim_pos: int,
heads: int,
dim_head: int,
rope_min_freq: float = 1.0,
rope_max_freq: float = 10000.0,
rope_p_zero_freqs: float = 0.0, # proportion of frequencies set to 0
):
super().__init__()
n_freqs = dim_head // 2
n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
omega = cat((
torch.zeros(n_zero_freqs),
rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
))
directions = rearrange(
make_directions(heads * n_freqs, dim_pos),
'(h f) p -> h f p',
h = heads
)
omega_expanded = rearrange(omega, 'f -> f 1')
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
def forward(self, input: Tensor, pos: Tensor) -> Tensor:
# input shape: (b, h, n, d) where d = head_dim
# pos shape: (b, n, p) where p = pos_dim
# self.freqs shape: (h, f, p) where f = d // 2
x, y = input.float().chunk(2, dim = -1) # both (b, h, n, f)
# Expand dimensions for broadcasting
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
# Compute theta for each (batch, head, seq, freq)
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)
# Apply rotation
x_out = x * cos_theta - y * sin_theta
y_out = x * sin_theta + y * cos_theta
output = cat((x_out, y_out), dim=-1)
return output.type_as(input)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., rotary_emb = None):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.rotary_emb = rotary_emb
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, pos = None):
x = self.norm(x)
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
# Apply rotary embeddings if available
if exists(self.rotary_emb):
assert exists(pos)
q = self.rotary_emb(q, pos)
k = self.rotary_emb(k, pos)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., rotary_emb = None):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, rotary_emb = rotary_emb),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, pos = None):
for attn, ff in self.layers:
x = attn(x, pos) + x
x = ff(x) + x
return self.norm(x)
class ViTND(Module):
def __init__(
self,
*,
ndim: int,
input_shape: int | tuple[int, ...],
patch_size: int | tuple[int, ...],
num_classes: int,
dim: int,
depth: int,
heads: int,
mlp_dim: int,
channels: int = 3,
dim_head: int = 64,
dropout: float = 0.,
emb_dropout: float = 0.,
rope_min_freq: float = 1.0,
rope_max_freq: float = 10000.0,
rope_p_zero_freqs: float = 0.0
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
self.ndim = ndim
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.dropout = nn.Dropout(emb_dropout)
# Create rotary embeddings
self.rotary_emb = GoldenGateRoPENd(
dim_pos = ndim,
heads = heads,
dim_head = dim_head,
rope_min_freq = rope_min_freq,
rope_max_freq = rope_max_freq,
rope_p_zero_freqs = rope_p_zero_freqs
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, rotary_emb = self.rotary_emb)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def muon_parameters(self):
params = []
for m in self.modules():
if isinstance(m, Attention):
params.extend([
m.to_v.weight,
m.to_out[0].weight
])
elif isinstance(m, FeedForward):
params.extend([
m.net[1].weight,
m.net[-2].weight
])
return params
def forward(
self,
x,
return_embed = False
):
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
batch, *spatial_dims, _, device = *x.shape, x.device
# Generate position coordinates
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
grid = torch.meshgrid(*grids, indexing = 'ij')
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
# flatten spatial dimensions for attention with nd rotary
pos = repeat(pos, '... p -> b (...) p', b = batch)
x, packed_shape = pack([x], 'b * d')
x = self.dropout(x)
embed = self.transformer(x, pos)
# return the embed with reconstituted patch shape
if return_embed:
embed, = unpack(embed, packed_shape, 'b * d')
return embed
# pooling to logits
pooled = reduce(embed, 'b n d -> b d', 'mean')
pooled = self.to_latent(pooled)
return self.mlp_head(pooled)
if __name__ == '__main__':
model = ViTND(
ndim = 5,
input_shape = (4, 8, 16, 32, 64),
patch_size = (2, 2, 4, 4, 8),
num_classes = 1000,
dim = 512,
depth = 6,
heads = 8,
mlp_dim = 2048,
channels = 3,
dropout = 0.1,
emb_dropout = 0.1
)
data = torch.randn(2, 3, 4, 8, 16, 32, 64)
logits = model(data)
embed = model(data, return_embed = True) # (2, 2, 4, 4, 8, 8, 512)

View File

@@ -78,30 +78,6 @@ class Transformer(nn.Module):
x = ff(x) + x
return self.norm(x)
class FactorizedTransformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
b, f, n, _ = x.shape
for spatial_attn, temporal_attn, ff in self.layers:
x = rearrange(x, 'b f n d -> (b f) n d')
x = spatial_attn(x) + x
x = rearrange(x, '(b f) n d -> (b n) f d', b=b, f=f)
x = temporal_attn(x) + x
x = ff(x) + x
x = rearrange(x, '(b n) f d -> b f n d', b=b, n=n)
return self.norm(x)
class ViT(nn.Module):
def __init__(
self,
@@ -120,8 +96,7 @@ class ViT(nn.Module):
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
variant = 'factorized_encoder',
emb_dropout = 0.
):
super().__init__()
image_height, image_width = pair(image_size)
@@ -129,7 +104,6 @@ class ViT(nn.Module):
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
num_frame_patches = (frames // frame_patch_size)
@@ -151,20 +125,15 @@ class ViT(nn.Module):
self.dropout = nn.Dropout(emb_dropout)
self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
if variant == 'factorized_encoder':
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
elif variant == 'factorized_self_attention':
assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
self.variant = variant
def forward(self, video):
x = self.to_patch_embedding(video)
@@ -178,37 +147,32 @@ class ViT(nn.Module):
x = self.dropout(x)
if self.variant == 'factorized_encoder':
x = rearrange(x, 'b f n d -> (b f) n d')
x = rearrange(x, 'b f n d -> (b f) n d')
# attend across space
# attend across space
x = self.spatial_transformer(x)
x = rearrange(x, '(b f) n d -> b f n d', b = b)
x = self.spatial_transformer(x)
# excise out the spatial cls tokens or average pool for temporal attention
x = rearrange(x, '(b f) n d -> b f n d', b = b)
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
# excise out the spatial cls tokens or average pool for temporal attention
# append temporal CLS tokens
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
if exists(self.temporal_cls_token):
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
# append temporal CLS tokens
x = torch.cat((temporal_cls_tokens, x), dim = 1)
if exists(self.temporal_cls_token):
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
# attend across time
x = torch.cat((temporal_cls_tokens, x), dim = 1)
x = self.temporal_transformer(x)
# attend across time
# excise out temporal cls token or average pool
x = self.temporal_transformer(x)
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
# excise out temporal cls token or average pool
elif self.variant == 'factorized_self_attention':
x = self.factorized_transformer(x)
x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
x = self.to_latent(x)
return self.mlp_head(x)