mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5d6c3f38f | ||
|
|
39fd9ac8be | ||
|
|
3becf087bb | ||
|
|
f6bc14c81d | ||
|
|
845c844b3b | ||
|
|
5f2bc0c796 |
4
.github/workflows/python-publish.yml
vendored
4
.github/workflows/python-publish.yml
vendored
@@ -18,9 +18,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.x'
|
||||
- name: Install dependencies
|
||||
|
||||
11
.github/workflows/python-test.yml
vendored
11
.github/workflows/python-test.yml
vendored
@@ -18,18 +18,17 @@ jobs:
|
||||
python-version: [3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pytest
|
||||
python -m pip install wheel
|
||||
python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
python -m pip install -e .
|
||||
python -m pip install pytest
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
python setup.py test
|
||||
pytest -q
|
||||
|
||||
20
README.md
20
README.md
@@ -2181,4 +2181,24 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{xiong2025ndrope,
|
||||
author = {Jerry Xiong},
|
||||
title = {On n-dimensional rotary positional embeddings},
|
||||
year = {2025},
|
||||
url = {https://jerryxio.ng/posts/nd-rope/}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{anonymous2025vat,
|
||||
title = {{VAT}: Vision Action Transformer by Unlocking Full Representation of ViT},
|
||||
author = {Anonymous},
|
||||
booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
|
||||
year = {2025},
|
||||
url = {https://openreview.net/forum?id=TalHOvvLZu},
|
||||
note = {under review}
|
||||
}
|
||||
```
|
||||
|
||||
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
|
||||
|
||||
63
pyproject.toml
Normal file
63
pyproject.toml
Normal file
@@ -0,0 +1,63 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vit-pytorch"
|
||||
version = "1.14.0"
|
||||
description = "Vision Transformer (ViT) - Pytorch"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
license = { file = "LICENSE" }
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" },
|
||||
]
|
||||
requires-python = ">=3.8"
|
||||
keywords = [
|
||||
"artificial intelligence",
|
||||
"attention mechanism",
|
||||
"image recognition",
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
dependencies = [
|
||||
"einops>=0.7.0",
|
||||
"torch>=1.10",
|
||||
"torchvision",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest",
|
||||
"torch==2.4.0",
|
||||
"torchvision==0.19.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/lucidrains/vit-pytorch"
|
||||
Repository = "https://github.com/lucidrains/vit-pytorch"
|
||||
|
||||
[tool.setuptools]
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["vit_pytorch*"]
|
||||
exclude = ["examples*", "tests*", "test*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests", "."]
|
||||
python_files = ["test_*.py", "*_test.py"]
|
||||
addopts = "-q"
|
||||
filterwarnings = [
|
||||
"ignore::FutureWarning",
|
||||
]
|
||||
42
setup.py
42
setup.py
@@ -1,42 +0,0 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
with open('README.md') as f:
|
||||
long_description = f.read()
|
||||
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.11.7',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description = long_description,
|
||||
long_description_content_type = 'text/markdown',
|
||||
author = 'Phil Wang',
|
||||
author_email = 'lucidrains@gmail.com',
|
||||
url = 'https://github.com/lucidrains/vit-pytorch',
|
||||
keywords = [
|
||||
'artificial intelligence',
|
||||
'attention mechanism',
|
||||
'image recognition'
|
||||
],
|
||||
install_requires=[
|
||||
'einops>=0.7.0',
|
||||
'torch>=1.10',
|
||||
'torchvision'
|
||||
],
|
||||
setup_requires=[
|
||||
'pytest-runner',
|
||||
],
|
||||
tests_require=[
|
||||
'pytest',
|
||||
'torch==2.4.0',
|
||||
'torchvision==0.19.0'
|
||||
],
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
'Intended Audience :: Developers',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
],
|
||||
)
|
||||
BIN
tests/.DS_Store
vendored
Normal file
BIN
tests/.DS_Store
vendored
Normal file
Binary file not shown.
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from vit_pytorch import ViT
|
||||
|
||||
def test():
|
||||
def test_vit():
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
444
vit_pytorch/vat.py
Normal file
444
vit_pytorch/vat.py
Normal file
@@ -0,0 +1,444 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, cat, stack, tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
hidden_dim,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
cross_attend = False
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.cross_attend = cross_attend
|
||||
self.context_norm = nn.LayerNorm(dim) if cross_attend else None
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x, context = None):
|
||||
|
||||
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross attending, or vice versa'
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
# handle norming of context for cross attention
|
||||
|
||||
kv_input = x
|
||||
|
||||
if self.cross_attend:
|
||||
context = self.context_norm(context)
|
||||
kv_input = context
|
||||
|
||||
# project for queries, keys, values
|
||||
|
||||
qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
dim_head,
|
||||
mlp_dim,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_hiddens = False
|
||||
):
|
||||
|
||||
hiddens = []
|
||||
|
||||
for attn, ff in self.layers:
|
||||
hiddens.append(x)
|
||||
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
if not return_hiddens:
|
||||
return x
|
||||
|
||||
return x, hiddens
|
||||
|
||||
class ViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
pool = 'cls',
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img, return_hiddens = False):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x, hiddens = self.transformer(x, return_hiddens = True)
|
||||
|
||||
# return the representation trajectory
|
||||
|
||||
if return_hiddens:
|
||||
return x, stack(hiddens)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
# proposed VAT
|
||||
|
||||
# https://openreview.net/forum?id=TalHOvvLZu
|
||||
# simple way to get SOTA on Libero dataset (beating fine-tuned pi-zero)
|
||||
|
||||
class VAT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
vit: ViT | dict,
|
||||
*,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
dim_head,
|
||||
dim_action,
|
||||
mlp_dim,
|
||||
num_views = None,
|
||||
dim_extra_token = None,
|
||||
action_chunk_len = 7,
|
||||
time_seq_len = 1,
|
||||
dropout = 0.,
|
||||
add_self_attn = True, # in the paper, they didn't have any ways for the action token to exchange information with the extra token, so we'll just add it as an option
|
||||
self_attn_heads = 4,
|
||||
self_attn_dim_head = 32,
|
||||
vit_layer_indices: tuple[int, ...] | None = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(vit, dict):
|
||||
vit = ViT(**vit)
|
||||
|
||||
self.vit = vit
|
||||
|
||||
vit_dim = vit.dim
|
||||
|
||||
assert vit.depth == depth or exists(vit_layer_indices), f'if the VAT depth is not equal to the ViT depth, you must pass in the indices from the ViT to be layered to the VAT in order from bottom to top'
|
||||
|
||||
vit_layer_indices = default(vit_layer_indices, tuple(range(depth)))
|
||||
|
||||
assert len(vit_layer_indices) == depth, f'number of vit layer indices {len(vit_layer_indices)} does not much the VAT depth {depth}'
|
||||
|
||||
self.register_buffer('layer_indices', tensor(vit_layer_indices), persistent = False)
|
||||
|
||||
# handle maybe multiple frames
|
||||
|
||||
is_video = time_seq_len > 1
|
||||
|
||||
self.is_video = is_video
|
||||
self.time_seq_len = time_seq_len
|
||||
self.time_pos_emb = nn.Parameter(torch.randn(time_seq_len, vit_dim) * 1e-2) if is_video else None
|
||||
|
||||
# maybe view embeddings
|
||||
|
||||
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
|
||||
|
||||
# to action tokens
|
||||
|
||||
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
|
||||
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
|
||||
|
||||
self.layers.append(ModuleList([
|
||||
maybe_self_attn,
|
||||
Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
|
||||
FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
self.final_norm = nn.LayerNorm(dim)
|
||||
self.to_pred_action = nn.Linear(dim, dim_action, bias = False)
|
||||
|
||||
# handle the extra token
|
||||
|
||||
self.accept_extra_token = exists(dim_extra_token)
|
||||
|
||||
if exists(dim_extra_token):
|
||||
self.to_extra_token = nn.Linear(dim_extra_token, dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
|
||||
extra = None, # (b d) - batch, dim extra
|
||||
actions = None, # (b k d) - batch, action chunk length, action dimension
|
||||
):
|
||||
batch = video_or_image.shape[0]
|
||||
return_loss = exists(actions)
|
||||
|
||||
# handle some various input dimensions
|
||||
|
||||
if video_or_image.ndim == 4:
|
||||
video_or_image = rearrange(video_or_image, 'b 1 c h w')
|
||||
|
||||
assert (
|
||||
(video_or_image.ndim == 5 and not self.is_video) or
|
||||
(video_or_image.ndim == 6 and self.is_video)
|
||||
)
|
||||
|
||||
if video_or_image.ndim == 5:
|
||||
video_or_image = rearrange(video_or_image, 'b v c h w -> b v c 1 h w')
|
||||
|
||||
assert video_or_image.shape[3] == self.time_seq_len
|
||||
|
||||
# to images
|
||||
|
||||
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
|
||||
|
||||
images, packed_shape = pack([images], '* c h w')
|
||||
|
||||
# get representation trajectory from vit
|
||||
|
||||
embed, hiddens = self.vit(images, return_hiddens = True)
|
||||
|
||||
hiddens = cat((hiddens, embed[None, ...]))
|
||||
|
||||
# extract the hiddens needed for the action cross attention
|
||||
|
||||
hiddens = hiddens[self.layer_indices]
|
||||
|
||||
# pack temporarily for embedding
|
||||
|
||||
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers
|
||||
|
||||
# maybe add time embeddings
|
||||
|
||||
if self.is_video:
|
||||
time_pos_emb = rearrange(time_pos_emb, 't d -> t 1 d')
|
||||
hiddens = hiddens + time_pos_emb
|
||||
|
||||
# maybe view embeddings
|
||||
|
||||
if exists(self.view_emb):
|
||||
assert self.view_emb.shape[0] == hiddens.shape[2]
|
||||
|
||||
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
|
||||
hiddens = hiddens + view_emb
|
||||
|
||||
# cross from actions to representation trajectory
|
||||
|
||||
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
|
||||
|
||||
# get main action tokens and maybe append extra
|
||||
|
||||
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
|
||||
|
||||
has_extra = exists(extra)
|
||||
|
||||
if has_extra:
|
||||
assert self.accept_extra_token
|
||||
|
||||
extra_token = self.to_extra_token(extra)
|
||||
|
||||
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
|
||||
|
||||
# cross attention
|
||||
|
||||
for (maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
|
||||
|
||||
action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
|
||||
|
||||
if exists(maybe_self_attn):
|
||||
action_tokens = maybe_self_attn(action_tokens) + action_tokens
|
||||
|
||||
action_tokens = ff(action_tokens) + action_tokens
|
||||
|
||||
# maybe unpack extra
|
||||
|
||||
if has_extra:
|
||||
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
|
||||
|
||||
# norm and prediction
|
||||
|
||||
action_tokens = self.final_norm(action_tokens)
|
||||
|
||||
pred_action = self.to_pred_action(action_tokens)
|
||||
|
||||
if not return_loss:
|
||||
return pred_action
|
||||
|
||||
assert pred_action.shape[1] == actions.shape[1]
|
||||
|
||||
# they found l1 loss suffices
|
||||
|
||||
return F.l1_loss(pred_action, actions)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
vit = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 512,
|
||||
heads = 8,
|
||||
depth = 4,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
vat = VAT(
|
||||
vit,
|
||||
dim = 512,
|
||||
depth = 9,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
mlp_dim = 2048,
|
||||
dim_action = 20,
|
||||
action_chunk_len = 7,
|
||||
time_seq_len = 1,
|
||||
num_views = 2,
|
||||
add_self_attn = True,
|
||||
dim_extra_token = 33, # extra token with some variable dimension
|
||||
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
|
||||
0, 0, 1, 1, 2, 2, 3, 3, 4
|
||||
)
|
||||
)
|
||||
|
||||
images = torch.randn(2, 2, 3, 256, 256) # (2 views with 4 frames)
|
||||
extra = torch.randn(2, 33) # extra internal state
|
||||
|
||||
actions = torch.randn(2, 7, 20) # actions for learning
|
||||
|
||||
loss = vat(images, actions = actions, extra = extra)
|
||||
loss.backward()
|
||||
|
||||
# after much training
|
||||
|
||||
pred_actions = vat(images)
|
||||
|
||||
assert pred_actions.shape == (2, 7, 20)
|
||||
191
vit_pytorch/vit_nd.py
Normal file
191
vit_pytorch/vit_nd.py
Normal file
@@ -0,0 +1,191 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def join(arr, delimiter = ' '):
|
||||
return delimiter.join(arr)
|
||||
|
||||
def ensure_tuple(t, length):
|
||||
if isinstance(t, (tuple, list)):
|
||||
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
|
||||
return tuple(t)
|
||||
return (t,) * length
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class ViTND(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ndim: int,
|
||||
input_shape: int | tuple[int, ...],
|
||||
patch_size: int | tuple[int, ...],
|
||||
num_classes: int,
|
||||
dim: int,
|
||||
depth: int,
|
||||
heads: int,
|
||||
mlp_dim: int,
|
||||
pool: str = 'cls',
|
||||
channels: int = 3,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.,
|
||||
emb_dropout: float = 0.
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.ndim = ndim
|
||||
self.pool = pool
|
||||
|
||||
input_shape = ensure_tuple(input_shape, ndim)
|
||||
patch_size = ensure_tuple(patch_size, ndim)
|
||||
|
||||
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
|
||||
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
|
||||
|
||||
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
|
||||
num_patches = 1
|
||||
for n in num_patches_per_dim:
|
||||
num_patches *= n
|
||||
|
||||
patch_dim = channels
|
||||
for p in patch_size:
|
||||
patch_dim *= p
|
||||
|
||||
dim_names = 'fghijkl'[:ndim]
|
||||
|
||||
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
|
||||
patch_dims = [f'p{i}' for i in range(ndim)]
|
||||
|
||||
input_pattern = f'b c {join(input_dims)}'
|
||||
output_pattern = f'b ({join(dim_names)}) ({join(patch_dims)} c)'
|
||||
rearrange_str = f'{input_pattern} -> {output_pattern}'
|
||||
|
||||
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange(rearrange_str, **rearrange_kwargs),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.to_patch_embedding(x)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x[:, 1:].mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model = ViTND(
|
||||
ndim = 4,
|
||||
input_shape = (8, 16, 32, 64),
|
||||
patch_size = (2, 4, 4, 8),
|
||||
num_classes = 1000,
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
channels = 3,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
occupancy_time = torch.randn(2, 3, 8, 16, 32, 64)
|
||||
|
||||
logits = model(occupancy_time)
|
||||
325
vit_pytorch/vit_nd_rotary.py
Normal file
325
vit_pytorch/vit_nd_rotary.py
Normal file
@@ -0,0 +1,325 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn, arange, cat, stack, Tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat, reduce, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1, p = 2)
|
||||
|
||||
def join(arr, delimiter = ' '):
|
||||
return delimiter.join(arr)
|
||||
|
||||
def ensure_tuple(t, length):
|
||||
if isinstance(t, (tuple, list)):
|
||||
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
|
||||
return tuple(t)
|
||||
return (t,) * length
|
||||
|
||||
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
||||
# https://jerryxio.ng/posts/nd-rope/
|
||||
|
||||
def _phi(m: int) -> float:
|
||||
x = 2.0
|
||||
for _ in range(10):
|
||||
x = (1 + x) ** (1.0 / (m + 1.0))
|
||||
return x
|
||||
|
||||
def make_directions(n: int, d: int) -> Tensor:
|
||||
g = _phi(d)
|
||||
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
|
||||
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
|
||||
z = torch.fmod(i * alpha, 1.0)
|
||||
directions = torch.erfinv(2.0 * z - 1.0)
|
||||
directions = l2norm(directions)
|
||||
return directions.float()
|
||||
|
||||
class GoldenGateRoPENd(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_pos: int,
|
||||
heads: int,
|
||||
dim_head: int,
|
||||
rope_min_freq: float = 1.0,
|
||||
rope_max_freq: float = 10000.0,
|
||||
rope_p_zero_freqs: float = 0.0, # proportion of frequencies set to 0
|
||||
):
|
||||
super().__init__()
|
||||
n_freqs = dim_head // 2
|
||||
n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
|
||||
|
||||
omega = cat((
|
||||
torch.zeros(n_zero_freqs),
|
||||
rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
|
||||
))
|
||||
|
||||
directions = rearrange(
|
||||
make_directions(heads * n_freqs, dim_pos),
|
||||
'(h f) p -> h f p',
|
||||
h = heads
|
||||
)
|
||||
|
||||
omega_expanded = rearrange(omega, 'f -> f 1')
|
||||
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
|
||||
|
||||
def forward(self, input: Tensor, pos: Tensor) -> Tensor:
|
||||
# input shape: (b, h, n, d) where d = head_dim
|
||||
# pos shape: (b, n, p) where p = pos_dim
|
||||
# self.freqs shape: (h, f, p) where f = d // 2
|
||||
|
||||
x, y = input.float().chunk(2, dim = -1) # both (b, h, n, f)
|
||||
|
||||
# Expand dimensions for broadcasting
|
||||
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
|
||||
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
|
||||
|
||||
# Compute theta for each (batch, head, seq, freq)
|
||||
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
|
||||
|
||||
cos_theta = torch.cos(theta)
|
||||
sin_theta = torch.sin(theta)
|
||||
|
||||
# Apply rotation
|
||||
x_out = x * cos_theta - y * sin_theta
|
||||
y_out = x * sin_theta + y * cos_theta
|
||||
|
||||
output = cat((x_out, y_out), dim=-1)
|
||||
return output.type_as(input)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., rotary_emb = None):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.rotary_emb = rotary_emb
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x, pos = None):
|
||||
x = self.norm(x)
|
||||
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
# Apply rotary embeddings if available
|
||||
if exists(self.rotary_emb):
|
||||
assert exists(pos)
|
||||
q = self.rotary_emb(q, pos)
|
||||
k = self.rotary_emb(k, pos)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., rotary_emb = None):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, rotary_emb = rotary_emb),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x, pos = None):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, pos) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class ViTND(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ndim: int,
|
||||
input_shape: int | tuple[int, ...],
|
||||
patch_size: int | tuple[int, ...],
|
||||
num_classes: int,
|
||||
dim: int,
|
||||
depth: int,
|
||||
heads: int,
|
||||
mlp_dim: int,
|
||||
channels: int = 3,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.,
|
||||
emb_dropout: float = 0.,
|
||||
rope_min_freq: float = 1.0,
|
||||
rope_max_freq: float = 10000.0,
|
||||
rope_p_zero_freqs: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
|
||||
|
||||
self.ndim = ndim
|
||||
|
||||
input_shape = ensure_tuple(input_shape, ndim)
|
||||
patch_size = ensure_tuple(patch_size, ndim)
|
||||
|
||||
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
|
||||
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
|
||||
|
||||
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
|
||||
num_patches = 1
|
||||
for n in num_patches_per_dim:
|
||||
num_patches *= n
|
||||
|
||||
patch_dim = channels
|
||||
for p in patch_size:
|
||||
patch_dim *= p
|
||||
|
||||
dim_names = 'fghijkl'[:ndim]
|
||||
|
||||
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
|
||||
patch_dims = [f'p{i}' for i in range(ndim)]
|
||||
|
||||
input_pattern = f'b c {join(input_dims)}'
|
||||
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
|
||||
rearrange_str = f'{input_pattern} -> {output_pattern}'
|
||||
|
||||
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange(rearrange_str, **rearrange_kwargs),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
# Create rotary embeddings
|
||||
self.rotary_emb = GoldenGateRoPENd(
|
||||
dim_pos = ndim,
|
||||
heads = heads,
|
||||
dim_head = dim_head,
|
||||
rope_min_freq = rope_min_freq,
|
||||
rope_max_freq = rope_max_freq,
|
||||
rope_p_zero_freqs = rope_p_zero_freqs
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, rotary_emb = self.rotary_emb)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def muon_parameters(self):
|
||||
params = []
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, Attention):
|
||||
params.extend([
|
||||
m.to_v.weight,
|
||||
m.to_out[0].weight
|
||||
])
|
||||
elif isinstance(m, FeedForward):
|
||||
params.extend([
|
||||
m.net[1].weight,
|
||||
m.net[-2].weight
|
||||
])
|
||||
|
||||
return params
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_embed = False
|
||||
):
|
||||
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
|
||||
|
||||
batch, *spatial_dims, _, device = *x.shape, x.device
|
||||
|
||||
# Generate position coordinates
|
||||
|
||||
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
|
||||
grid = torch.meshgrid(*grids, indexing = 'ij')
|
||||
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
|
||||
|
||||
# flatten spatial dimensions for attention with nd rotary
|
||||
|
||||
pos = repeat(pos, '... p -> b (...) p', b = batch)
|
||||
x, packed_shape = pack([x], 'b * d')
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
embed = self.transformer(x, pos)
|
||||
|
||||
# return the embed with reconstituted patch shape
|
||||
|
||||
if return_embed:
|
||||
embed, = unpack(embed, packed_shape, 'b * d')
|
||||
return embed
|
||||
|
||||
# pooling to logits
|
||||
|
||||
pooled = reduce(embed, 'b n d -> b d', 'mean')
|
||||
|
||||
pooled = self.to_latent(pooled)
|
||||
return self.mlp_head(pooled)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model = ViTND(
|
||||
ndim = 5,
|
||||
input_shape = (4, 8, 16, 32, 64),
|
||||
patch_size = (2, 2, 4, 4, 8),
|
||||
num_classes = 1000,
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
channels = 3,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
data = torch.randn(2, 3, 4, 8, 16, 32, 64)
|
||||
|
||||
logits = model(data)
|
||||
|
||||
embed = model(data, return_embed = True) # (2, 2, 4, 4, 8, 8, 512)
|
||||
Reference in New Issue
Block a user