Compare commits

...

30 Commits

Author SHA1 Message Date
lucidrains
98cbdab5a4 have a language model address https://github.com/lucidrains/vit-pytorch/issues/348 2025-09-25 06:12:37 -07:00
lucidrains
f6bc14c81d able to return embed from vit-nd-rotary 2025-09-23 07:21:34 -07:00
lucidrains
845c844b3b add a vit nd with rotary nd, from Jerry Xiong at UIUC 2025-09-21 10:45:42 -07:00
lucidrains
5f2bc0c796 with assistance from claude (yes it did the einops equation building here), generalize to n-dimensions 2025-09-21 06:22:43 -07:00
lucidrains
35bf273037 1.11.7 2025-08-17 18:07:42 -07:00
Baraa sameeh
1123063a5e Make all CCT regularization parameters user-configurable. (#346) 2025-08-17 18:07:25 -07:00
lucidrains
f8bec5ede2 able to project the image embedding before applying time positional embedding for accept video wrapper 2025-08-13 10:15:18 -07:00
lucidrains
297e7d00a2 handle channel first for accept video wrapper 2025-08-03 08:29:40 -07:00
lucidrains
29ac8e143c fix when video time seq len less than max time seq len for video acceptor 2025-07-27 09:00:56 -07:00
lucidrains
e05cd6d8b8 some models only return embeddings with some kwarg on forward 2025-07-27 08:46:43 -07:00
lucidrains
b46233c3d6 need to be able to invoke with eval no grad 2025-07-27 08:25:58 -07:00
lucidrains
68e13a3c7d bit more flexible 2025-07-27 08:14:48 -07:00
lucidrains
b22dc0ecd2 add a wrapper for accepting video and processing the images individually, optionally able to add time positional embeddings - for use in two robotics work 2025-07-27 08:05:48 -07:00
lucidrains
db05a141a6 add the proposed jumbo vit from Fuller et al. of Carleton University 2025-03-05 10:50:34 -08:00
lucidrains
9f49a31977 1.9.2 2025-01-19 05:53:11 -08:00
JacobLinCool
ab63fc9cc8 remove duplicated qkv computation in na_vit_nested_tensor_3d.py (#341) 2025-01-19 05:52:46 -08:00
Phil Wang
c3018d1433 1.9.1 2025-01-04 07:55:49 -08:00
Kale Kundert
b7ed6bad28 add option to set frame padding for 3D CCT (#339) 2025-01-04 07:55:27 -08:00
lucidrains
e7cba9ba6d add a simple vit flavor for a new bytedance paper that proposes to break out of the traditional one residual stream architecture - "hyper-connections" 2024-12-20 17:43:50 -08:00
lucidrains
56373c0cbd make value residual learned 2024-11-24 08:21:28 -08:00
lucidrains
24196a3e8a allow for qk norm to be turned off for na vit nested tensor 2024-11-20 10:59:22 -08:00
Phil Wang
f6d7287b6b readme 2024-11-19 08:20:38 -08:00
lucidrains
d47c57e32f fix tests 2024-11-10 09:43:54 -08:00
lucidrains
0449865786 update minimum version for nested tensor of NaViT 2024-11-10 09:37:48 -08:00
lucidrains
6693d47d0b update comment for navit 3d 2024-11-07 20:02:07 -08:00
Phil Wang
141239ca86 fix value residual 2024-10-31 06:48:24 -07:00
lucidrains
0b5c9b4559 add value residual based simple vit 2024-10-28 09:19:00 -07:00
lucidrains
e300cdd7dc fix multiheaded qk rmsnorm in nViT 2024-10-10 19:15:17 -07:00
Phil Wang
36ddc7a6ba go all the way with the normalized vit, fix some scales 2024-10-10 10:42:37 -07:00
Phil Wang
1d1a63fc5c cite for hypersphere vit adapted from ngpt 2024-10-10 10:15:04 -07:00
16 changed files with 1436 additions and 101 deletions

View File

@@ -18,9 +18,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Install dependencies

View File

@@ -18,18 +18,17 @@ jobs:
python-version: [3.8, 3.9]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest
python -m pip install wheel
python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
python -m pip install -e .
python -m pip install pytest
- name: Test with pytest
run: |
python setup.py test
pytest -q

View File

@@ -198,7 +198,7 @@ preds = v(
) # (5, 1000)
```
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.4` and import as follows
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.5` and import as follows
```python
import torch
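# the remainder of this README snippet is truncated in the diff hunk; a minimal sketch of the
# assumed continuation follows — the module path mirrors the nested-tensor files changed further
# down in this compare view and is not shown verbatim here
from vit_pytorch.na_vit_nested_tensor import NaViT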
@@ -2142,4 +2142,52 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@inproceedings{Liu2017DeepHL,
title = {Deep Hyperspherical Learning},
author = {Weiyang Liu and Yanming Zhang and Xingguo Li and Zhen Liu and Bo Dai and Tuo Zhao and Le Song},
booktitle = {Neural Information Processing Systems},
year = {2017},
url = {https://api.semanticscholar.org/CorpusID:5104558}
}
```
```bibtex
@inproceedings{Zhou2024ValueRL,
title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```
```bibtex
@article{Zhu2024HyperConnections,
title = {Hyper-Connections},
author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
journal = {ArXiv},
year = {2024},
volume = {abs/2409.19606},
url = {https://api.semanticscholar.org/CorpusID:272987528}
}
```
```bibtex
@inproceedings{Fuller2025SimplerFV,
title = {Simpler Fast Vision Transformers with a Jumbo CLS Token},
author = {Anthony Fuller and Yousef Yassin and Daniel G. Kyrollos and Evan Shelhamer and James R. Green},
year = {2025},
url = {https://api.semanticscholar.org/CorpusID:276557720}
}
```
```bibtex
@misc{xiong2025ndrope,
author = {Jerry Xiong},
title = {On n-dimensional rotary positional embeddings},
year = {2025},
url = {https://jerryxio.ng/posts/nd-rope/}
}
```
*I visualise a time when we will be to robots what dogs are to humans, and I'm rooting for the machines.* — Claude Shannon

pyproject.toml Normal file
View File

@@ -0,0 +1,63 @@
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.12.3"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" },
]
requires-python = ">=3.8"
keywords = [
"artificial intelligence",
"attention mechanism",
"image recognition",
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
dependencies = [
"einops>=0.7.0",
"torch>=1.10",
"torchvision",
]
[project.optional-dependencies]
test = [
"pytest",
"torch==2.4.0",
"torchvision==0.19.0",
]
[project.urls]
Homepage = "https://github.com/lucidrains/vit-pytorch"
Repository = "https://github.com/lucidrains/vit-pytorch"
[tool.setuptools]
include-package-data = true
[tool.setuptools.packages.find]
include = ["vit_pytorch*"]
exclude = ["examples*", "tests*", "test*"]
[tool.pytest.ini_options]
testpaths = ["tests", "."]
python_files = ["test_*.py", "*_test.py"]
addopts = "-q"
filterwarnings = [
"ignore::FutureWarning",
]

View File

@@ -1,42 +0,0 @@
from setuptools import setup, find_packages
with open('README.md') as f:
long_description = f.read()
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.8.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
long_description_content_type = 'text/markdown',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
url = 'https://github.com/lucidrains/vit-pytorch',
keywords = [
'artificial intelligence',
'attention mechanism',
'image recognition'
],
install_requires=[
'einops>=0.7.0',
'torch>=1.10',
'torchvision'
],
setup_requires=[
'pytest-runner',
],
tests_require=[
'pytest',
'torch==2.4.0',
'torchvision==0.19.0'
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)

View File

@@ -0,0 +1,161 @@
from contextlib import nullcontext
import torch
from torch import is_tensor, randn
from torch.nn import Module, Linear, Parameter
from torch.utils._pytree import tree_flatten, tree_unflatten
from einops import rearrange, repeat
# helper functions
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
# classes
class AcceptVideoWrapper(Module):
def __init__(
self,
image_net: Module,
forward_function = 'forward',
add_time_pos_emb = False,
dim_emb = None,
time_seq_len = None,
embed_is_channel_first = False,
output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
proj_embed_to_dim = None
):
super().__init__()
self.image_net = image_net
self.forward_function = forward_function # for openclip, used in TRI-LBM
self.add_time_pos_emb = add_time_pos_emb
self.output_pos_add_pos_emb = output_pos_add_pos_emb
# maybe project the image embedding
self.embed_proj = None
if exists(proj_embed_to_dim):
assert exists(dim_emb), '`dim_emb` must be passed in'
self.embed_proj = Linear(dim_emb, proj_embed_to_dim)
# time positional embedding
if add_time_pos_emb:
assert exists(dim_emb) and exists(time_seq_len), '`dim_emb` and `time_seq_len` must be set if adding positional embeddings to the output'
self.time_seq_len = time_seq_len
dim_pos_emb = default(proj_embed_to_dim, dim_emb)
self.pos_emb = Parameter(randn(time_seq_len, dim_pos_emb) * 1e-2)
self.embed_is_channel_first = embed_is_channel_first
def forward(
self,
video, # (b c t h w)
eval_with_no_grad = False,
forward_kwargs = dict()
):
add_time_pos_emb = self.add_time_pos_emb
time = video.shape[2]
# maybe validate time positional embedding
if add_time_pos_emb:
assert time <= self.time_seq_len, f'received video with {time} frames but `time_seq_len` ({self.time_seq_len}) is too low'
video = rearrange(video, 'b c t h w -> b t c h w')
video = rearrange(video, 'b t ... -> (b t) ...')
# forward through image net for outputs
func = getattr(self.image_net, self.forward_function)
if eval_with_no_grad:
self.image_net.eval()
context = torch.no_grad if eval_with_no_grad else nullcontext
with context():
outputs = func(video, **forward_kwargs)
# handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
outputs, tree_spec = tree_flatten(outputs)
outputs = tuple(rearrange(t, '(b t) ... -> b t ...', t = time) if is_tensor(t) and t.numel() > 1 else t for t in outputs)
# maybe project embedding
if exists(self.embed_proj):
outputs = list(outputs)
embed = outputs[self.output_pos_add_pos_emb]
outputs[self.output_pos_add_pos_emb] = self.embed_proj(embed)
# maybe add time positional embedding
if add_time_pos_emb:
outputs = list(outputs)
embed = outputs[self.output_pos_add_pos_emb]
pos_emb = rearrange(self.pos_emb, 't d -> 1 t d')
# handle the network outputting embeddings with spatial dimensions intact - assume embedded dimension is last
dims_to_unsqueeze = embed.ndim - pos_emb.ndim
one_dims = ((1,) * dims_to_unsqueeze)
if self.embed_is_channel_first:
pos_emb = pos_emb.reshape(*pos_emb.shape, *one_dims)
else:
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *one_dims, pos_emb.shape[-1])
pos_emb = pos_emb[:, :embed.shape[1]]
embed = embed + pos_emb
outputs[self.output_pos_add_pos_emb] = embed
return tree_unflatten(outputs, tree_spec)
# main
if __name__ == '__main__':
from vit_pytorch import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
videos = torch.randn(1, 3, 7, 256, 256)
# step up the difficulty and return embeddings for robotics
from vit_pytorch.extractor import Extractor
v = Extractor(v)
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024, proj_embed_to_dim = 512)
logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
assert logits.shape == (1, 7, 1000)
assert embeddings.shape == (1, 7, 65, 512)

View File

@@ -316,6 +316,9 @@ class CCT(nn.Module):
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
dropout_rate=0.,
attention_dropout=0.1,
stochastic_depth_rate=0.1,
*args, **kwargs
):
super().__init__()
@@ -340,9 +343,9 @@ class CCT(nn.Module):
width=img_width),
embedding_dim=embedding_dim,
seq_pool=True,
dropout_rate=0.,
attention_dropout=0.1,
stochastic_depth=0.1,
dropout_rate=dropout_rate,
attention_dropout=attention_dropout,
stochastic_depth_rate=stochastic_depth_rate,
*args, **kwargs)
def forward(self, x):
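With this change the three regularization values are ordinary constructor arguments instead of hard-coded literals. Below is a minimal sketch of passing them through; the sizing arguments (`img_size`, `embedding_dim`, `num_layers`, and so on) are assumed to follow the CCT example in the repository README and are not part of this diff.

```python
import torch
from vit_pytorch.cct import CCT

cct = CCT(
    img_size = (224, 448),
    embedding_dim = 384,
    n_conv_layers = 2,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_layers = 14,
    num_heads = 6,
    mlp_ratio = 3.,
    num_classes = 1000,
    dropout_rate = 0.1,            # was fixed at 0.
    attention_dropout = 0.05,      # was fixed at 0.1
    stochastic_depth_rate = 0.2    # was fixed at 0.1
)

img = torch.randn(1, 3, 224, 448)
logits = cct(img)  # (1, 1000)
```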

View File

@@ -167,8 +167,10 @@ class Tokenizer(nn.Module):
stride,
padding,
frame_stride=1,
frame_padding=None,
frame_pooling_stride=1,
frame_pooling_kernel_size=1,
frame_pooling_padding=None,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
@@ -188,16 +190,22 @@ class Tokenizer(nn.Module):
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
if frame_padding is None:
frame_padding = frame_kernel_size // 2
if frame_pooling_padding is None:
frame_pooling_padding = frame_pooling_kernel_size // 2
self.conv_layers = nn.Sequential(
*[nn.Sequential(
nn.Conv3d(chan_in, chan_out,
kernel_size=(frame_kernel_size, kernel_size, kernel_size),
stride=(frame_stride, stride, stride),
padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias),
padding=(frame_padding, padding, padding), bias=conv_bias),
nn.Identity() if not exists(activation) else activation(),
nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
stride=(frame_pooling_stride, pooling_stride, pooling_stride),
padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
padding=(frame_pooling_padding, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
)
for chan_in, chan_out in n_filter_list_pairs
])
@@ -324,8 +332,10 @@ class CCT(nn.Module):
n_conv_layers=1,
frame_stride=1,
frame_kernel_size=3,
frame_padding=None,
frame_pooling_kernel_size=1,
frame_pooling_stride=1,
frame_pooling_padding=None,
kernel_size=7,
stride=2,
padding=3,
@@ -342,8 +352,10 @@ class CCT(nn.Module):
n_output_channels=embedding_dim,
frame_stride=frame_stride,
frame_kernel_size=frame_kernel_size,
frame_padding=frame_padding,
frame_pooling_stride=frame_pooling_stride,
frame_pooling_kernel_size=frame_pooling_kernel_size,
frame_pooling_padding=frame_pooling_padding,
kernel_size=kernel_size,
stride=stride,
padding=padding,
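The new `frame_padding` and `frame_pooling_padding` keywords override the kernel-size-derived defaults along the temporal axis. A minimal sketch follows, assuming the 3D CCT lives in `vit_pytorch.cct_3d` and accepts the usual sizing arguments (`img_size`, `num_frames`, `num_classes` are assumptions here); only the `frame_*` keywords come from this diff.

```python
import torch
from vit_pytorch.cct_3d import CCT  # assumed module path

cct = CCT(
    img_size = 224,
    num_frames = 8,                   # assumed name for the temporal extent
    embedding_dim = 384,
    n_conv_layers = 2,
    frame_kernel_size = 3,
    frame_padding = 0,                # new: overrides the default frame_kernel_size // 2
    frame_pooling_kernel_size = 1,
    frame_pooling_padding = 0,        # new: overrides the default frame_pooling_kernel_size // 2
    kernel_size = 7,
    stride = 2,
    padding = 3,
    num_layers = 6,
    num_heads = 6,
    mlp_ratio = 3.,
    num_classes = 1000
)

video = torch.randn(1, 3, 8, 224, 224)  # (batch, channels, frames, height, width)
logits = cct(video)                      # (1, 1000)
```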

vit_pytorch/jumbo_vit.py Normal file
View File

@@ -0,0 +1,204 @@
# Simpler Fast Vision Transformers with a Jumbo CLS Token
# https://arxiv.org/abs/2502.15021
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def divisible_by(num, den):
return (num % den) == 0
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = temperature ** -omega
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pos_emb = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pos_emb.type(dtype)
# classes
def FeedForward(dim, mult = 4.):
hidden_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class JumboViT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
num_jumbo_cls = 1, # differing from paper, allow for multiple jumbo cls, so one could break it up into 2 jumbo cls tokens with 3x the dim, as an example
jumbo_cls_k = 6, # they use a CLS token with this factor times the dimension - 6 was the value they settled on
jumbo_ff_mult = 2, # expansion factor of the jumbo cls token feedforward
channels = 3,
dim_head = 64
):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
jumbo_cls_dim = dim * jumbo_cls_k
self.jumbo_cls_token = nn.Parameter(torch.zeros(num_jumbo_cls, jumbo_cls_dim))
jumbo_cls_to_tokens = Rearrange('b n (k d) -> b (n k) d', k = jumbo_cls_k)
self.jumbo_cls_to_tokens = jumbo_cls_to_tokens
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
# attention and feedforwards
self.jumbo_ff = nn.Sequential(
Rearrange('b (n k) d -> b n (k d)', k = jumbo_cls_k),
FeedForward(jumbo_cls_dim, int(jumbo_cls_dim * jumbo_ff_mult)), # they use separate parameters for the jumbo feedforward, weight tied for parameter efficient
jumbo_cls_to_tokens
)
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim),
]))
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
batch, device = img.shape[0], img.device
x = self.to_patch_embedding(img)
# pos embedding
pos_emb = self.pos_embedding.to(device, dtype = x.dtype)
x = x + pos_emb
# add cls tokens
cls_tokens = repeat(self.jumbo_cls_token, 'nj d -> b nj d', b = batch)
jumbo_tokens = self.jumbo_cls_to_tokens(cls_tokens)
x, cls_packed_shape = pack([jumbo_tokens, x], 'b * d')
# attention and feedforwards
for layer, (attn, ff) in enumerate(self.layers, start = 1):
is_last = layer == len(self.layers)
x = attn(x) + x
# jumbo feedforward
jumbo_cls_tokens, x = unpack(x, cls_packed_shape, 'b * d')
x = ff(x) + x
jumbo_cls_tokens = self.jumbo_ff(jumbo_cls_tokens) + jumbo_cls_tokens
if is_last:
continue
x, _ = pack([jumbo_cls_tokens, x], 'b * d')
pooled = reduce(jumbo_cls_tokens, 'b n d -> b d', 'mean')
# normalization and project to logits
embed = self.norm(pooled)
embed = self.to_latent(embed)
logits = self.linear_head(embed)
return logits
# copy pasteable file
if __name__ == '__main__':
v = JumboViT(
num_classes = 1000,
image_size = 64,
patch_size = 8,
dim = 16,
depth = 2,
heads = 2,
mlp_dim = 32,
jumbo_cls_k = 3,
jumbo_ff_mult = 2,
)
images = torch.randn(1, 3, 64, 64)
logits = v(images)
assert logits.shape == (1, 1000)

View File

@@ -6,9 +6,6 @@ from functools import partial
import torch
import packaging.version as pkg_version
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'):
print('nested tensor NaViT was tested on pytorch 2.4')
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
@@ -44,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
super().__init__()
self.norm = nn.LayerNorm(dim, bias = False)
@@ -59,8 +56,8 @@ class Attention(Module):
# in the paper, they employ qk rmsnorm, a way to stabilize attention
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
self.query_norm = nn.LayerNorm(dim_head, bias = False)
self.key_norm = nn.LayerNorm(dim_head, bias = False)
self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
self.dropout = dropout
@@ -114,13 +111,13 @@ class Attention(Module):
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
@@ -149,9 +146,15 @@ class NaViT(Module):
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
qk_rmsnorm = True,
token_dropout_prob: float | None = None
):
super().__init__()
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
print('nested tensor NaViT was tested on pytorch 2.5')
image_height, image_width = pair(image_size)
# what percent of tokens to dropout
@@ -182,7 +185,7 @@ class NaViT(Module):
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
# final attention pooling queries
@@ -323,3 +326,5 @@ if __name__ == '__main__':
]
assert v(images).shape == (5, 1000)
v(images).sum().backward()
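The behavioral change in this file is the new `qk_rmsnorm` flag, threaded from the NaViT constructor down to each Attention block (where it swaps the query/key layernorms for identity). A minimal sketch of turning it off, with the remaining constructor arguments assumed to match the non-nested NaViT example in the README:

```python
import torch
from vit_pytorch.na_vit_nested_tensor import NaViT

v = NaViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    qk_rmsnorm = False,       # new: disable the query/key norm added for attention stability
    token_dropout_prob = 0.1
)

# variable-resolution images go in as a plain list and are packed into a nested tensor internally
images = [torch.randn(3, 256, 256), torch.randn(3, 128, 128)]
preds = v(images)  # (2, 1000)
```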

View File

@@ -6,9 +6,6 @@ from functools import partial
import torch
import packaging.version as pkg_version
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'):
print('nested tensor NaViT was tested on pytorch 2.4')
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
@@ -44,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
super().__init__()
self.norm = nn.LayerNorm(dim, bias = False)
@@ -59,8 +56,8 @@ class Attention(Module):
# in the paper, they employ qk rmsnorm, a way to stabilize attention
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
self.query_norm = nn.LayerNorm(dim_head, bias = False)
self.key_norm = nn.LayerNorm(dim_head, bias = False)
self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
self.dropout = dropout
@@ -86,17 +83,6 @@ class Attention(Module):
# split heads
def split_heads(t):
return t.unflatten(-1, (self.heads, self.dim_head)).transpose(1, 2).contiguous()
# queries, keys, values
query = self.to_queries(x)
key = self.to_keys(context)
value = self.to_values(context)
# split heads
def split_heads(t):
return t.unflatten(-1, (self.heads, self.dim_head))
@@ -126,13 +112,13 @@ class Attention(Module):
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
@@ -164,11 +150,15 @@ class NaViT(Module):
dropout = 0.,
emb_dropout = 0.,
num_registers = 4,
qk_rmsnorm = True,
token_dropout_prob: float | None = None
):
super().__init__()
image_height, image_width = pair(image_size)
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
print('nested tensor NaViT was tested on pytorch 2.5')
# what percent of tokens to dropout
# if int or float given, then assume constant dropout prob
# otherwise accept a callback that in turn calculates dropout prob from height and width
@@ -209,7 +199,7 @@ class NaViT(Module):
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
# final attention pooling queries
@@ -336,7 +326,7 @@ class NaViT(Module):
if __name__ == '__main__':
# works for torch 2.4
# works for torch 2.5
v = NaViT(
image_size = 256,
@@ -362,3 +352,5 @@ if __name__ == '__main__':
]
assert v(volumes).shape == (5, 1000)
v(volumes).sum().backward()

View File

@@ -76,8 +76,8 @@ class Attention(Module):
self.dropout = dropout
self.q_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
self.k_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
self.merge_heads = Rearrange('b h n d -> b n (h d)')
@@ -90,15 +90,15 @@ class Attention(Module):
):
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
q = q * self.q_scale
k = k * self.k_scale
q, k, v = map(self.split_heads, (q, k, v))
# query key rmsnorm
q, k = map(l2norm, (q, k))
q = q * self.q_scale
k = k * self.k_scale
# scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
out = F.scaled_dot_product_attention(
@@ -179,18 +179,18 @@ class nViT(Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
NormLinear(patch_dim, dim, norm_dim_in = False),
)
self.abs_pos_emb = nn.Embedding(num_patches, dim)
self.abs_pos_emb = NormLinear(dim, num_patches)
residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
# layers
self.dim = dim
self.scale = dim ** 0.5
self.layers = ModuleList([])
self.residual_lerp_scales = nn.ParameterList([])
@@ -201,8 +201,8 @@ class nViT(Module):
]))
self.residual_lerp_scales.append(nn.ParameterList([
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
]))
self.logit_scale = nn.Parameter(torch.ones(num_classes))
@@ -225,22 +225,23 @@ class nViT(Module):
tokens = self.to_patch_embedding(images)
pos_emb = self.abs_pos_emb(torch.arange(tokens.shape[-2], device = device))
seq_len = tokens.shape[-2]
pos_emb = self.abs_pos_emb.weight[torch.arange(seq_len, device = device)]
tokens = l2norm(tokens + pos_emb)
for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):
attn_out = l2norm(attn(tokens))
tokens = l2norm(tokens.lerp(attn_out, attn_alpha))
tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale))
ff_out = l2norm(ff(tokens))
tokens = l2norm(tokens.lerp(ff_out, ff_alpha))
tokens = l2norm(tokens.lerp(ff_out, ff_alpha * self.scale))
pooled = reduce(tokens, 'b n d -> b d', 'mean')
logits = self.to_pred(pooled)
logits = logits * self.logit_scale * (self.dim ** 0.5)
logits = logits * self.logit_scale * self.scale
return logits

View File

@@ -0,0 +1,233 @@
"""
ViT + Hyper-Connections + Register Tokens
https://arxiv.org/abs/2409.19606
"""
import torch
from torch import nn, tensor
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange
# b - batch, h - heads, n - sequence, e - expansion rate / residual streams, d - feature dimension
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# hyper connections
class HyperConnection(Module):
def __init__(
self,
dim,
num_residual_streams,
layer_index
):
""" Appendix J - Algorithm 2, Dynamic only """
super().__init__()
self.norm = nn.LayerNorm(dim, bias = False)
self.num_residual_streams = num_residual_streams
self.layer_index = layer_index
self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
init_alpha0 = torch.zeros((num_residual_streams, 1))
init_alpha0[layer_index % num_residual_streams, 0] = 1.
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
self.dynamic_alpha_scale = nn.Parameter(tensor(1e-2))
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
self.dynamic_beta_scale = nn.Parameter(tensor(1e-2))
def width_connection(self, residuals):
normed = self.norm(residuals)
wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
alpha = dynamic_alpha + self.static_alpha
dc_weight = (normed @ self.dynamic_beta_fn).tanh()
dynamic_beta = dc_weight * self.dynamic_beta_scale
beta = dynamic_beta + self.static_beta
# width connection
mix_h = einsum(alpha, residuals, '... e1 e2, ... e1 d -> ... e2 d')
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
return branch_input, residuals, beta
def depth_connection(
self,
branch_output,
residuals,
beta
):
return einsum(branch_output, beta, "b n d, b n e -> b n e d") + residuals
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_residual_streams):
super().__init__()
self.num_residual_streams = num_residual_streams
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for layer_index in range(depth):
self.layers.append(nn.ModuleList([
HyperConnection(dim, num_residual_streams, layer_index),
Attention(dim, heads = heads, dim_head = dim_head),
HyperConnection(dim, num_residual_streams, layer_index),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
x = repeat(x, 'b n d -> b n e d', e = self.num_residual_streams)
for attn_hyper_conn, attn, ff_hyper_conn, ff in self.layers:
x, attn_res, beta = attn_hyper_conn.width_connection(x)
x = attn(x)
x = attn_hyper_conn.depth_connection(x, attn_res, beta)
x, ff_res, beta = ff_hyper_conn.width_connection(x)
x = ff(x)
x = ff_hyper_conn.depth_connection(x, ff_res, beta)
x = reduce(x, 'b n e d -> b n d', 'sum')
return self.norm(x)
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_residual_streams, num_register_tokens = 4, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
batch, device = img.shape[0], img.device
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(x)
r = repeat(self.register_tokens, 'n d -> b n d', b = batch)
x, ps = pack([x, r], 'b * d')
x = self.transformer(x)
x, _ = unpack(x, ps, 'b * d')
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
# main
if __name__ == '__main__':
vit = SimpleViT(
num_classes = 1000,
image_size = 256,
patch_size = 8,
dim = 1024,
depth = 12,
heads = 8,
mlp_dim = 2048,
num_residual_streams = 8
)
images = torch.randn(3, 3, 256, 256)
logits = vit(images)

View File

@@ -0,0 +1,159 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# classes
def FeedForward(dim, hidden_dim):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
self.to_residual_mix = nn.Sequential(
nn.Linear(dim, heads),
nn.Sigmoid(),
Rearrange('b n h -> b h n 1')
) if learned_value_residual_mix else (lambda _: 0.5)
def forward(self, x, value_residual = None):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
if exists(value_residual):
mix = self.to_residual_mix(x)
v = v * mix + value_residual * (1. - mix)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out), v
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for i in range(depth):
is_first = i == 0
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
value_residual = None
for attn, ff in self.layers:
attn_out, values = attn(x, value_residual = value_residual)
value_residual = default(value_residual, values)
x = attn_out + x
x = ff(x) + x
return self.norm(x)
class SimpleViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
device = img.device
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype=x.dtype)
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
# quick test
if __name__ == '__main__':
v = SimpleViT(
num_classes = 1000,
image_size = 256,
patch_size = 8,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
)
images = torch.randn(2, 3, 256, 256)
logits = v(images)

vit_pytorch/vit_nd.py Normal file
View File

@@ -0,0 +1,191 @@
from __future__ import annotations
import torch
from torch import nn
from torch.nn import Module
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def join(arr, delimiter = ' '):
return delimiter.join(arr)
def ensure_tuple(t, length):
if isinstance(t, (tuple, list)):
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
return tuple(t)
return (t,) * length
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class ViTND(Module):
def __init__(
self,
*,
ndim: int,
input_shape: int | tuple[int, ...],
patch_size: int | tuple[int, ...],
num_classes: int,
dim: int,
depth: int,
heads: int,
mlp_dim: int,
pool: str = 'cls',
channels: int = 3,
dim_head: int = 64,
dropout: float = 0.,
emb_dropout: float = 0.
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.ndim = ndim
self.pool = pool
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b ({join(dim_names)}) ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, x):
x = self.to_patch_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x[:, 1:].mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
if __name__ == '__main__':
model = ViTND(
ndim = 4,
input_shape = (8, 16, 32, 64),
patch_size = (2, 4, 4, 8),
num_classes = 1000,
dim = 512,
depth = 6,
heads = 8,
mlp_dim = 2048,
channels = 3,
dropout = 0.1,
emb_dropout = 0.1
)
occupancy_time = torch.randn(2, 3, 8, 16, 32, 64)
logits = model(occupancy_time)

View File

@@ -0,0 +1,306 @@
from __future__ import annotations
import torch
from torch import nn, arange, cat, stack, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
def join(arr, delimiter = ' '):
return delimiter.join(arr)
def ensure_tuple(t, length):
if isinstance(t, (tuple, list)):
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
return tuple(t)
return (t,) * length
# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/
def _phi(m: int) -> float:
x = 2.0
for _ in range(10):
x = (1 + x) ** (1.0 / (m + 1.0))
return x
def make_directions(n: int, d: int) -> Tensor:
g = _phi(d)
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
z = torch.fmod(i * alpha, 1.0)
directions = torch.erfinv(2.0 * z - 1.0)
directions = l2norm(directions)
return directions.float()
class GoldenGateRoPENd(Module):
def __init__(
self,
dim_pos: int,
heads: int,
dim_head: int,
rope_min_freq: float = 1.0,
rope_max_freq: float = 10000.0,
rope_p_zero_freqs: float = 0.0, # proportion of frequencies set to 0
):
super().__init__()
n_freqs = dim_head // 2
n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
omega = cat((
torch.zeros(n_zero_freqs),
rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
))
directions = rearrange(
make_directions(heads * n_freqs, dim_pos),
'(h f) p -> h f p',
h = heads
)
omega_expanded = rearrange(omega, 'f -> f 1')
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
def forward(self, input: Tensor, pos: Tensor) -> Tensor:
# input shape: (b, h, n, d) where d = head_dim
# pos shape: (b, n, p) where p = pos_dim
# self.freqs shape: (h, f, p) where f = d // 2
x, y = input.float().chunk(2, dim = -1) # both (b, h, n, f)
# Expand dimensions for broadcasting
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
# Compute theta for each (batch, head, seq, freq)
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)
# Apply rotation
x_out = x * cos_theta - y * sin_theta
y_out = x * sin_theta + y * cos_theta
output = cat((x_out, y_out), dim=-1)
return output.type_as(input)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., rotary_emb = None):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.rotary_emb = rotary_emb
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, pos = None):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
# Apply rotary embeddings if available
if exists(self.rotary_emb):
assert exists(pos)
q = self.rotary_emb(q, pos)
k = self.rotary_emb(k, pos)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., rotary_emb = None):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, rotary_emb = rotary_emb),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, pos = None):
for attn, ff in self.layers:
x = attn(x, pos) + x
x = ff(x) + x
return self.norm(x)
class ViTND(Module):
def __init__(
self,
*,
ndim: int,
input_shape: int | tuple[int, ...],
patch_size: int | tuple[int, ...],
num_classes: int,
dim: int,
depth: int,
heads: int,
mlp_dim: int,
channels: int = 3,
dim_head: int = 64,
dropout: float = 0.,
emb_dropout: float = 0.,
rope_min_freq: float = 1.0,
rope_max_freq: float = 10000.0,
rope_p_zero_freqs: float = 0.0
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
self.ndim = ndim
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.dropout = nn.Dropout(emb_dropout)
# Create rotary embeddings
self.rotary_emb = GoldenGateRoPENd(
dim_pos = ndim,
heads = heads,
dim_head = dim_head,
rope_min_freq = rope_min_freq,
rope_max_freq = rope_max_freq,
rope_p_zero_freqs = rope_p_zero_freqs
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, rotary_emb = self.rotary_emb)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def forward(
self,
x,
return_embed = False
):
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
batch, *spatial_dims, _, device = *x.shape, x.device
# Generate position coordinates
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
grid = torch.meshgrid(*grids, indexing = 'ij')
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
# flatten spatial dimensions for attention with nd rotary
pos = repeat(pos, '... p -> b (...) p', b = batch)
x, packed_shape = pack([x], 'b * d')
x = self.dropout(x)
embed = self.transformer(x, pos)
# return the embed with reconstituted patch shape
if return_embed:
embed, = unpack(embed, packed_shape, 'b * d')
return embed
# pooling to logits
pooled = reduce(embed, 'b n d -> b d', 'mean')
pooled = self.to_latent(pooled)
return self.mlp_head(pooled)
if __name__ == '__main__':
model = ViTND(
ndim = 5,
input_shape = (4, 8, 16, 32, 64),
patch_size = (2, 2, 4, 4, 8),
num_classes = 1000,
dim = 512,
depth = 6,
heads = 8,
mlp_dim = 2048,
channels = 3,
dropout = 0.1,
emb_dropout = 0.1
)
data = torch.randn(2, 3, 4, 8, 16, 32, 64)
logits = model(data)
embed = model(data, return_embed = True) # (2, 2, 4, 4, 8, 8, 512)