mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39fd9ac8be | ||
|
|
3becf087bb | ||
|
|
f6bc14c81d | ||
|
|
845c844b3b | ||
|
|
5f2bc0c796 | ||
|
|
35bf273037 | ||
|
|
1123063a5e | ||
|
|
f8bec5ede2 | ||
|
|
297e7d00a2 | ||
|
|
29ac8e143c | ||
|
|
e05cd6d8b8 | ||
|
|
b46233c3d6 | ||
|
|
68e13a3c7d |
4
.github/workflows/python-publish.yml
vendored
4
.github/workflows/python-publish.yml
vendored
@@ -18,9 +18,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.x'
|
||||
- name: Install dependencies
|
||||
|
||||
11
.github/workflows/python-test.yml
vendored
11
.github/workflows/python-test.yml
vendored
@@ -18,18 +18,17 @@ jobs:
|
||||
python-version: [3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pytest
|
||||
python -m pip install wheel
|
||||
python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
python -m pip install -e .
|
||||
python -m pip install pytest
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
python setup.py test
|
||||
pytest -q
|
||||
|
||||
@@ -2181,4 +2181,13 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{xiong2025ndrope,
|
||||
author = {Jerry Xiong},
|
||||
title = {On n-dimensional rotary positional embeddings},
|
||||
year = {2025},
|
||||
url = {https://jerryxio.ng/posts/nd-rope/}
|
||||
}
|
||||
```
|
||||
|
||||
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
|
||||
|
||||
63
pyproject.toml
Normal file
63
pyproject.toml
Normal file
@@ -0,0 +1,63 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vit-pytorch"
|
||||
version = "1.12.5"
|
||||
description = "Vision Transformer (ViT) - Pytorch"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
license = { file = "LICENSE" }
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" },
|
||||
]
|
||||
requires-python = ">=3.8"
|
||||
keywords = [
|
||||
"artificial intelligence",
|
||||
"attention mechanism",
|
||||
"image recognition",
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
dependencies = [
|
||||
"einops>=0.7.0",
|
||||
"torch>=1.10",
|
||||
"torchvision",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest",
|
||||
"torch==2.4.0",
|
||||
"torchvision==0.19.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/lucidrains/vit-pytorch"
|
||||
Repository = "https://github.com/lucidrains/vit-pytorch"
|
||||
|
||||
[tool.setuptools]
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["vit_pytorch*"]
|
||||
exclude = ["examples*", "tests*", "test*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests", "."]
|
||||
python_files = ["test_*.py", "*_test.py"]
|
||||
addopts = "-q"
|
||||
filterwarnings = [
|
||||
"ignore::FutureWarning",
|
||||
]
|
||||
42
setup.py
42
setup.py
@@ -1,42 +0,0 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
with open('README.md') as f:
|
||||
long_description = f.read()
|
||||
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.11.0',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description = long_description,
|
||||
long_description_content_type = 'text/markdown',
|
||||
author = 'Phil Wang',
|
||||
author_email = 'lucidrains@gmail.com',
|
||||
url = 'https://github.com/lucidrains/vit-pytorch',
|
||||
keywords = [
|
||||
'artificial intelligence',
|
||||
'attention mechanism',
|
||||
'image recognition'
|
||||
],
|
||||
install_requires=[
|
||||
'einops>=0.7.0',
|
||||
'torch>=1.10',
|
||||
'torchvision'
|
||||
],
|
||||
setup_requires=[
|
||||
'pytest-runner',
|
||||
],
|
||||
tests_require=[
|
||||
'pytest',
|
||||
'torch==2.4.0',
|
||||
'torchvision==0.19.0'
|
||||
],
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
'Intended Audience :: Developers',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
],
|
||||
)
|
||||
BIN
tests/.DS_Store
vendored
Normal file
BIN
tests/.DS_Store
vendored
Normal file
Binary file not shown.
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from vit_pytorch import ViT
|
||||
|
||||
def test():
|
||||
def test_vit():
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
@@ -1,6 +1,8 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from torch import is_tensor, randn
|
||||
from torch.nn import Module, Parameter
|
||||
from torch.nn import Module, Linear, Parameter
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
||||
|
||||
from einops import rearrange, repeat
|
||||
@@ -23,7 +25,9 @@ class AcceptVideoWrapper(Module):
|
||||
add_time_pos_emb = False,
|
||||
dim_emb = None,
|
||||
time_seq_len = None,
|
||||
output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
|
||||
embed_is_channel_first = False,
|
||||
output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
|
||||
proj_embed_to_dim = None
|
||||
):
|
||||
super().__init__()
|
||||
self.image_net = image_net
|
||||
@@ -32,18 +36,34 @@ class AcceptVideoWrapper(Module):
|
||||
self.add_time_pos_emb = add_time_pos_emb
|
||||
self.output_pos_add_pos_emb = output_pos_add_pos_emb
|
||||
|
||||
# maybe project the image embedding
|
||||
|
||||
self.embed_proj = None
|
||||
|
||||
if exists(proj_embed_to_dim):
|
||||
assert exists(dim_emb), '`dim_emb` must be passed in'
|
||||
self.embed_proj = Linear(dim_emb, proj_embed_to_dim)
|
||||
|
||||
# time positional embedding
|
||||
|
||||
if add_time_pos_emb:
|
||||
assert exists(dim_emb) and exists(time_seq_len), '`dim_emb` and `time_seq_len` must be set if adding positional embeddings to the output'
|
||||
self.time_seq_len = time_seq_len
|
||||
|
||||
self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
|
||||
dim_pos_emb = default(proj_embed_to_dim, dim_emb)
|
||||
|
||||
self.pos_emb = Parameter(randn(time_seq_len, dim_pos_emb) * 1e-2)
|
||||
|
||||
self.embed_is_channel_first = embed_is_channel_first
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video # (b c t h w)
|
||||
video, # (b c t h w)
|
||||
eval_with_no_grad = False,
|
||||
forward_kwargs = dict()
|
||||
):
|
||||
add_time_pos_emb = self.add_time_pos_emb
|
||||
batch, time = video.shape[0], video.shape[2]
|
||||
time = video.shape[2]
|
||||
|
||||
# maybe validate time positional embedding
|
||||
|
||||
@@ -54,9 +74,17 @@ class AcceptVideoWrapper(Module):
|
||||
|
||||
video = rearrange(video, 'b t ... -> (b t) ...')
|
||||
|
||||
# forward through image net for outputs
|
||||
|
||||
func = getattr(self.image_net, self.forward_function)
|
||||
|
||||
outputs = func(video)
|
||||
if eval_with_no_grad:
|
||||
self.image_net.eval()
|
||||
|
||||
context = torch.no_grad if eval_with_no_grad else nullcontext
|
||||
|
||||
with context():
|
||||
outputs = func(video, **forward_kwargs)
|
||||
|
||||
# handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
|
||||
|
||||
@@ -64,14 +92,37 @@ class AcceptVideoWrapper(Module):
|
||||
|
||||
outputs = tuple(rearrange(t, '(b t) ... -> b t ...', t = time) if is_tensor(t) and t.numel() > 1 else t for t in outputs)
|
||||
|
||||
# maybe project embedding
|
||||
|
||||
if exists(self.embed_proj):
|
||||
outputs = list(outputs)
|
||||
|
||||
embed = outputs[self.output_pos_add_pos_emb]
|
||||
|
||||
outputs[self.output_pos_add_pos_emb] = self.embed_proj(embed)
|
||||
|
||||
# maybe add time positional embedding
|
||||
|
||||
if add_time_pos_emb:
|
||||
pos_emb = repeat(self.pos_emb, 't d -> b t 1 d', b = batch)
|
||||
|
||||
outputs = list(outputs)
|
||||
embed = outputs[self.output_pos_add_pos_emb]
|
||||
|
||||
pos_emb = rearrange(self.pos_emb, 't d -> 1 t d')
|
||||
|
||||
# handle the network outputting embeddings with spatial dimensions intact - assume embedded dimension is last
|
||||
|
||||
dims_to_unsqueeze = embed.ndim - pos_emb.ndim
|
||||
|
||||
one_dims = ((1,) * dims_to_unsqueeze)
|
||||
|
||||
if self.embed_is_channel_first:
|
||||
pos_emb = pos_emb.reshape(*pos_emb.shape, *one_dims)
|
||||
else:
|
||||
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *one_dims, pos_emb.shape[-1])
|
||||
|
||||
pos_emb = pos_emb[:, :embed.shape[1]]
|
||||
|
||||
embed = embed + pos_emb
|
||||
|
||||
outputs[self.output_pos_add_pos_emb] = embed
|
||||
@@ -95,16 +146,16 @@ if __name__ == '__main__':
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
videos = torch.randn(1, 3, 10, 256, 256)
|
||||
videos = torch.randn(1, 3, 7, 256, 256)
|
||||
|
||||
# step up the difficulty and return embeddings for robotics
|
||||
|
||||
from vit_pytorch.extractor import Extractor
|
||||
v = Extractor(v)
|
||||
|
||||
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 10, dim_emb = 1024)
|
||||
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024, proj_embed_to_dim = 512)
|
||||
|
||||
logits, embeddings = video_acceptor(videos) # always (batch, channels, time, height, width) - time is always dimension 2
|
||||
logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
|
||||
|
||||
assert logits.shape == (1, 10, 1000)
|
||||
assert embeddings.shape == (1, 10, 65, 1024)
|
||||
assert logits.shape == (1, 7, 1000)
|
||||
assert embeddings.shape == (1, 7, 65, 512)
|
||||
|
||||
@@ -316,6 +316,9 @@ class CCT(nn.Module):
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
dropout_rate=0.,
|
||||
attention_dropout=0.1,
|
||||
stochastic_depth_rate=0.1,
|
||||
*args, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -340,9 +343,9 @@ class CCT(nn.Module):
|
||||
width=img_width),
|
||||
embedding_dim=embedding_dim,
|
||||
seq_pool=True,
|
||||
dropout_rate=0.,
|
||||
attention_dropout=0.1,
|
||||
stochastic_depth=0.1,
|
||||
dropout_rate=dropout_rate,
|
||||
attention_dropout=attention_dropout,
|
||||
stochastic_depth_rate=stochastic_depth_rate,
|
||||
*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
191
vit_pytorch/vit_nd.py
Normal file
191
vit_pytorch/vit_nd.py
Normal file
@@ -0,0 +1,191 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def join(arr, delimiter = ' '):
|
||||
return delimiter.join(arr)
|
||||
|
||||
def ensure_tuple(t, length):
|
||||
if isinstance(t, (tuple, list)):
|
||||
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
|
||||
return tuple(t)
|
||||
return (t,) * length
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class ViTND(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ndim: int,
|
||||
input_shape: int | tuple[int, ...],
|
||||
patch_size: int | tuple[int, ...],
|
||||
num_classes: int,
|
||||
dim: int,
|
||||
depth: int,
|
||||
heads: int,
|
||||
mlp_dim: int,
|
||||
pool: str = 'cls',
|
||||
channels: int = 3,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.,
|
||||
emb_dropout: float = 0.
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.ndim = ndim
|
||||
self.pool = pool
|
||||
|
||||
input_shape = ensure_tuple(input_shape, ndim)
|
||||
patch_size = ensure_tuple(patch_size, ndim)
|
||||
|
||||
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
|
||||
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
|
||||
|
||||
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
|
||||
num_patches = 1
|
||||
for n in num_patches_per_dim:
|
||||
num_patches *= n
|
||||
|
||||
patch_dim = channels
|
||||
for p in patch_size:
|
||||
patch_dim *= p
|
||||
|
||||
dim_names = 'fghijkl'[:ndim]
|
||||
|
||||
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
|
||||
patch_dims = [f'p{i}' for i in range(ndim)]
|
||||
|
||||
input_pattern = f'b c {join(input_dims)}'
|
||||
output_pattern = f'b ({join(dim_names)}) ({join(patch_dims)} c)'
|
||||
rearrange_str = f'{input_pattern} -> {output_pattern}'
|
||||
|
||||
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange(rearrange_str, **rearrange_kwargs),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.to_patch_embedding(x)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x[:, 1:].mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model = ViTND(
|
||||
ndim = 4,
|
||||
input_shape = (8, 16, 32, 64),
|
||||
patch_size = (2, 4, 4, 8),
|
||||
num_classes = 1000,
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
channels = 3,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
occupancy_time = torch.randn(2, 3, 8, 16, 32, 64)
|
||||
|
||||
logits = model(occupancy_time)
|
||||
325
vit_pytorch/vit_nd_rotary.py
Normal file
325
vit_pytorch/vit_nd_rotary.py
Normal file
@@ -0,0 +1,325 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn, arange, cat, stack, Tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat, reduce, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1, p = 2)
|
||||
|
||||
def join(arr, delimiter = ' '):
|
||||
return delimiter.join(arr)
|
||||
|
||||
def ensure_tuple(t, length):
|
||||
if isinstance(t, (tuple, list)):
|
||||
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
|
||||
return tuple(t)
|
||||
return (t,) * length
|
||||
|
||||
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
||||
# https://jerryxio.ng/posts/nd-rope/
|
||||
|
||||
def _phi(m: int) -> float:
|
||||
x = 2.0
|
||||
for _ in range(10):
|
||||
x = (1 + x) ** (1.0 / (m + 1.0))
|
||||
return x
|
||||
|
||||
def make_directions(n: int, d: int) -> Tensor:
|
||||
g = _phi(d)
|
||||
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
|
||||
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
|
||||
z = torch.fmod(i * alpha, 1.0)
|
||||
directions = torch.erfinv(2.0 * z - 1.0)
|
||||
directions = l2norm(directions)
|
||||
return directions.float()
|
||||
|
||||
class GoldenGateRoPENd(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_pos: int,
|
||||
heads: int,
|
||||
dim_head: int,
|
||||
rope_min_freq: float = 1.0,
|
||||
rope_max_freq: float = 10000.0,
|
||||
rope_p_zero_freqs: float = 0.0, # proportion of frequencies set to 0
|
||||
):
|
||||
super().__init__()
|
||||
n_freqs = dim_head // 2
|
||||
n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
|
||||
|
||||
omega = cat((
|
||||
torch.zeros(n_zero_freqs),
|
||||
rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
|
||||
))
|
||||
|
||||
directions = rearrange(
|
||||
make_directions(heads * n_freqs, dim_pos),
|
||||
'(h f) p -> h f p',
|
||||
h = heads
|
||||
)
|
||||
|
||||
omega_expanded = rearrange(omega, 'f -> f 1')
|
||||
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
|
||||
|
||||
def forward(self, input: Tensor, pos: Tensor) -> Tensor:
|
||||
# input shape: (b, h, n, d) where d = head_dim
|
||||
# pos shape: (b, n, p) where p = pos_dim
|
||||
# self.freqs shape: (h, f, p) where f = d // 2
|
||||
|
||||
x, y = input.float().chunk(2, dim = -1) # both (b, h, n, f)
|
||||
|
||||
# Expand dimensions for broadcasting
|
||||
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
|
||||
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
|
||||
|
||||
# Compute theta for each (batch, head, seq, freq)
|
||||
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
|
||||
|
||||
cos_theta = torch.cos(theta)
|
||||
sin_theta = torch.sin(theta)
|
||||
|
||||
# Apply rotation
|
||||
x_out = x * cos_theta - y * sin_theta
|
||||
y_out = x * sin_theta + y * cos_theta
|
||||
|
||||
output = cat((x_out, y_out), dim=-1)
|
||||
return output.type_as(input)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., rotary_emb = None):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.rotary_emb = rotary_emb
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x, pos = None):
|
||||
x = self.norm(x)
|
||||
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
# Apply rotary embeddings if available
|
||||
if exists(self.rotary_emb):
|
||||
assert exists(pos)
|
||||
q = self.rotary_emb(q, pos)
|
||||
k = self.rotary_emb(k, pos)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., rotary_emb = None):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, rotary_emb = rotary_emb),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x, pos = None):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, pos) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class ViTND(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ndim: int,
|
||||
input_shape: int | tuple[int, ...],
|
||||
patch_size: int | tuple[int, ...],
|
||||
num_classes: int,
|
||||
dim: int,
|
||||
depth: int,
|
||||
heads: int,
|
||||
mlp_dim: int,
|
||||
channels: int = 3,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.,
|
||||
emb_dropout: float = 0.,
|
||||
rope_min_freq: float = 1.0,
|
||||
rope_max_freq: float = 10000.0,
|
||||
rope_p_zero_freqs: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
|
||||
|
||||
self.ndim = ndim
|
||||
|
||||
input_shape = ensure_tuple(input_shape, ndim)
|
||||
patch_size = ensure_tuple(patch_size, ndim)
|
||||
|
||||
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
|
||||
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
|
||||
|
||||
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
|
||||
num_patches = 1
|
||||
for n in num_patches_per_dim:
|
||||
num_patches *= n
|
||||
|
||||
patch_dim = channels
|
||||
for p in patch_size:
|
||||
patch_dim *= p
|
||||
|
||||
dim_names = 'fghijkl'[:ndim]
|
||||
|
||||
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
|
||||
patch_dims = [f'p{i}' for i in range(ndim)]
|
||||
|
||||
input_pattern = f'b c {join(input_dims)}'
|
||||
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
|
||||
rearrange_str = f'{input_pattern} -> {output_pattern}'
|
||||
|
||||
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange(rearrange_str, **rearrange_kwargs),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
# Create rotary embeddings
|
||||
self.rotary_emb = GoldenGateRoPENd(
|
||||
dim_pos = ndim,
|
||||
heads = heads,
|
||||
dim_head = dim_head,
|
||||
rope_min_freq = rope_min_freq,
|
||||
rope_max_freq = rope_max_freq,
|
||||
rope_p_zero_freqs = rope_p_zero_freqs
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, rotary_emb = self.rotary_emb)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def muon_parameters(self):
|
||||
params = []
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, Attention):
|
||||
params.extend([
|
||||
m.to_v.weight,
|
||||
m.to_out[0].weight
|
||||
])
|
||||
elif isinstance(m, FeedForward):
|
||||
params.extend([
|
||||
m.net[1].weight,
|
||||
m.net[-2].weight
|
||||
])
|
||||
|
||||
return params
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_embed = False
|
||||
):
|
||||
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
|
||||
|
||||
batch, *spatial_dims, _, device = *x.shape, x.device
|
||||
|
||||
# Generate position coordinates
|
||||
|
||||
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
|
||||
grid = torch.meshgrid(*grids, indexing = 'ij')
|
||||
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
|
||||
|
||||
# flatten spatial dimensions for attention with nd rotary
|
||||
|
||||
pos = repeat(pos, '... p -> b (...) p', b = batch)
|
||||
x, packed_shape = pack([x], 'b * d')
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
embed = self.transformer(x, pos)
|
||||
|
||||
# return the embed with reconstituted patch shape
|
||||
|
||||
if return_embed:
|
||||
embed, = unpack(embed, packed_shape, 'b * d')
|
||||
return embed
|
||||
|
||||
# pooling to logits
|
||||
|
||||
pooled = reduce(embed, 'b n d -> b d', 'mean')
|
||||
|
||||
pooled = self.to_latent(pooled)
|
||||
return self.mlp_head(pooled)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model = ViTND(
|
||||
ndim = 5,
|
||||
input_shape = (4, 8, 16, 32, 64),
|
||||
patch_size = (2, 2, 4, 4, 8),
|
||||
num_classes = 1000,
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
channels = 3,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
data = torch.randn(2, 3, 4, 8, 16, 32, 64)
|
||||
|
||||
logits = model(data)
|
||||
|
||||
embed = model(data, return_embed = True) # (2, 2, 4, 4, 8, 8, 512)
|
||||
Reference in New Issue
Block a user