mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
53 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cff5e547a | ||
|
|
fdaf7f92b9 | ||
|
|
0ebd4edab9 | ||
|
|
aa49c2783a | ||
|
|
6aa0374313 | ||
|
|
b35a97de05 | ||
|
|
1374b93145 | ||
|
|
4386742cd1 | ||
|
|
5cf8384c56 | ||
|
|
f7d59cecb5 | ||
|
|
a583cb5988 | ||
|
|
25871013f5 | ||
|
|
e66862bcd5 | ||
|
|
39fd9ac8be | ||
|
|
3becf087bb | ||
|
|
f6bc14c81d | ||
|
|
845c844b3b | ||
|
|
5f2bc0c796 | ||
|
|
35bf273037 | ||
|
|
1123063a5e | ||
|
|
f8bec5ede2 | ||
|
|
297e7d00a2 | ||
|
|
29ac8e143c | ||
|
|
e05cd6d8b8 | ||
|
|
b46233c3d6 | ||
|
|
68e13a3c7d | ||
|
|
b22dc0ecd2 | ||
|
|
db05a141a6 | ||
|
|
9f49a31977 | ||
|
|
ab63fc9cc8 | ||
|
|
c3018d1433 | ||
|
|
b7ed6bad28 | ||
|
|
e7cba9ba6d | ||
|
|
56373c0cbd | ||
|
|
24196a3e8a | ||
|
|
f6d7287b6b | ||
|
|
d47c57e32f | ||
|
|
0449865786 | ||
|
|
6693d47d0b | ||
|
|
141239ca86 | ||
|
|
0b5c9b4559 | ||
|
|
e300cdd7dc | ||
|
|
36ddc7a6ba | ||
|
|
1d1a63fc5c | ||
|
|
74b62009f8 | ||
|
|
f50d7d1436 | ||
|
|
82f2fa751d | ||
|
|
fcb9501cdd | ||
|
|
c4651a35a3 | ||
|
|
9d43e4d0bb | ||
|
|
5e808f48d1 | ||
|
|
bed48b5912 | ||
|
|
73199ab486 |
4
.github/workflows/python-publish.yml
vendored
4
.github/workflows/python-publish.yml
vendored
@@ -18,9 +18,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.x'
|
||||
- name: Install dependencies
|
||||
|
||||
10
.github/workflows/python-test.yml
vendored
10
.github/workflows/python-test.yml
vendored
@@ -18,17 +18,17 @@ jobs:
|
||||
python-version: [3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
|
||||
python -m pip install -e .
|
||||
python -m pip install pytest
|
||||
python -m pip install wheel
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
python setup.py test
|
||||
pytest -q
|
||||
|
||||
118
README.md
118
README.md
@@ -198,6 +198,38 @@ preds = v(
|
||||
) # (5, 1000)
|
||||
```
|
||||
|
||||
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.5` and import as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.na_vit_nested_tensor import NaViT
|
||||
|
||||
v = NaViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
token_dropout_prob = 0.1
|
||||
)
|
||||
|
||||
# 5 images of different resolutions - List[Tensor]
|
||||
|
||||
images = [
|
||||
torch.randn(3, 256, 256), torch.randn(3, 128, 128),
|
||||
torch.randn(3, 128, 256), torch.randn(3, 256, 128),
|
||||
torch.randn(3, 64, 256)
|
||||
]
|
||||
|
||||
preds = v(images)
|
||||
|
||||
assert preds.shape == (5, 1000)
|
||||
```
|
||||
|
||||
## Distillation
|
||||
|
||||
<img src="./images/distill.png" width="300px"></img>
|
||||
@@ -1186,7 +1218,8 @@ pred = cct(video)
|
||||
|
||||
<img src="./images/vivit.png" width="350px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository will offer the first variant, which is a spatial transformer followed by a temporal one.
|
||||
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository includes the factorized encoder and the factorized self-attention variant.
|
||||
The factorized encoder variant is a spatial transformer followed by a temporal one. The factorized self-attention variant is a spatio-temporal transformer with alternating spatial and temporal self-attention layers.
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -1202,7 +1235,8 @@ v = ViT(
|
||||
spatial_depth = 6, # depth of the spatial transformer
|
||||
temporal_depth = 6, # depth of the temporal transformer
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
mlp_dim = 2048,
|
||||
variant = 'factorized_encoder', # or 'factorized_self_attention'
|
||||
)
|
||||
|
||||
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
|
||||
@@ -2099,4 +2133,84 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Loshchilov2024nGPTNT,
|
||||
title = {nGPT: Normalized Transformer with Representation Learning on the Hypersphere},
|
||||
author = {Ilya Loshchilov and Cheng-Ping Hsieh and Simeng Sun and Boris Ginsburg},
|
||||
year = {2024},
|
||||
url = {https://api.semanticscholar.org/CorpusID:273026160}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Liu2017DeepHL,
|
||||
title = {Deep Hyperspherical Learning},
|
||||
author = {Weiyang Liu and Yanming Zhang and Xingguo Li and Zhen Liu and Bo Dai and Tuo Zhao and Le Song},
|
||||
booktitle = {Neural Information Processing Systems},
|
||||
year = {2017},
|
||||
url = {https://api.semanticscholar.org/CorpusID:5104558}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Zhou2024ValueRL,
|
||||
title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
|
||||
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
|
||||
year = {2024},
|
||||
url = {https://api.semanticscholar.org/CorpusID:273532030}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Zhu2024HyperConnections,
|
||||
title = {Hyper-Connections},
|
||||
author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
|
||||
journal = {ArXiv},
|
||||
year = {2024},
|
||||
volume = {abs/2409.19606},
|
||||
url = {https://api.semanticscholar.org/CorpusID:272987528}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Fuller2025SimplerFV,
|
||||
title = {Simpler Fast Vision Transformers with a Jumbo CLS Token},
|
||||
author = {Anthony Fuller and Yousef Yassin and Daniel G. Kyrollos and Evan Shelhamer and James R. Green},
|
||||
year = {2025},
|
||||
url = {https://api.semanticscholar.org/CorpusID:276557720}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{xiong2025ndrope,
|
||||
author = {Jerry Xiong},
|
||||
title = {On n-dimensional rotary positional embeddings},
|
||||
year = {2025},
|
||||
url = {https://jerryxio.ng/posts/nd-rope/}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{anonymous2025vat,
|
||||
title = {{VAT}: Vision Action Transformer by Unlocking Full Representation of ViT},
|
||||
author = {Anonymous},
|
||||
booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
|
||||
year = {2025},
|
||||
url = {https://openreview.net/forum?id=TalHOvvLZu},
|
||||
note = {under review}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{carrigg2025decorrelationspeedsvisiontransformers,
|
||||
title = {Decorrelation Speeds Up Vision Transformers},
|
||||
author = {Kieran Carrigg and Rob van Gastel and Melda Yeghaian and Sander Dalm and Faysal Boughorbel and Marcel van Gerven},
|
||||
year = {2025},
|
||||
eprint = {2510.14657},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV},
|
||||
url = {https://arxiv.org/abs/2510.14657},
|
||||
}
|
||||
```
|
||||
|
||||
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
|
||||
|
||||
63
pyproject.toml
Normal file
63
pyproject.toml
Normal file
@@ -0,0 +1,63 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vit-pytorch"
|
||||
version = "1.16.2"
|
||||
description = "Vision Transformer (ViT) - Pytorch"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
license = { file = "LICENSE" }
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" },
|
||||
]
|
||||
requires-python = ">=3.8"
|
||||
keywords = [
|
||||
"artificial intelligence",
|
||||
"attention mechanism",
|
||||
"image recognition",
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
dependencies = [
|
||||
"einops>=0.7.0",
|
||||
"torch>=1.10",
|
||||
"torchvision",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest",
|
||||
"torch==2.4.0",
|
||||
"torchvision==0.19.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/lucidrains/vit-pytorch"
|
||||
Repository = "https://github.com/lucidrains/vit-pytorch"
|
||||
|
||||
[tool.setuptools]
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["vit_pytorch*"]
|
||||
exclude = ["examples*", "tests*", "test*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests", "."]
|
||||
python_files = ["test_*.py", "*_test.py"]
|
||||
addopts = "-q"
|
||||
filterwarnings = [
|
||||
"ignore::FutureWarning",
|
||||
]
|
||||
42
setup.py
42
setup.py
@@ -1,42 +0,0 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
with open('README.md') as f:
|
||||
long_description = f.read()
|
||||
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.7.5',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description=long_description,
|
||||
long_description_content_type = 'text/markdown',
|
||||
author = 'Phil Wang',
|
||||
author_email = 'lucidrains@gmail.com',
|
||||
url = 'https://github.com/lucidrains/vit-pytorch',
|
||||
keywords = [
|
||||
'artificial intelligence',
|
||||
'attention mechanism',
|
||||
'image recognition'
|
||||
],
|
||||
install_requires=[
|
||||
'einops>=0.7.0',
|
||||
'torch>=1.10',
|
||||
'torchvision'
|
||||
],
|
||||
setup_requires=[
|
||||
'pytest-runner',
|
||||
],
|
||||
tests_require=[
|
||||
'pytest',
|
||||
'torch==1.12.1',
|
||||
'torchvision==0.13.1'
|
||||
],
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
'Intended Audience :: Developers',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
],
|
||||
)
|
||||
BIN
tests/.DS_Store
vendored
Normal file
BIN
tests/.DS_Store
vendored
Normal file
Binary file not shown.
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from vit_pytorch import ViT
|
||||
|
||||
def test():
|
||||
def test_vit():
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
107
train_vit_decorr.py
Normal file
107
train_vit_decorr.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "accelerate",
|
||||
# "vit-pytorch",
|
||||
# "wandb"
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import torchvision.transforms as T
|
||||
from torchvision.datasets import CIFAR100
|
||||
|
||||
# constants
|
||||
|
||||
BATCH_SIZE = 32
|
||||
LEARNING_RATE = 3e-4
|
||||
EPOCHS = 10
|
||||
DECORR_LOSS_WEIGHT = 1e-1
|
||||
|
||||
TRACK_EXPERIMENT_ONLINE = False
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
# data
|
||||
|
||||
transform = T.Compose([
|
||||
T.ToTensor(),
|
||||
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
dataset = CIFAR100(
|
||||
root = 'data',
|
||||
download = True,
|
||||
train = True,
|
||||
transform = transform
|
||||
)
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)
|
||||
|
||||
# model
|
||||
|
||||
from vit_pytorch.vit_with_decorr import ViT
|
||||
|
||||
vit = ViT(
|
||||
dim = 128,
|
||||
num_classes = 100,
|
||||
image_size = 32,
|
||||
patch_size = 4,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
mlp_dim = 128 * 4,
|
||||
decorr_sample_frac = 1. # use all tokens
|
||||
)
|
||||
|
||||
# optim
|
||||
|
||||
from torch.optim import Adam
|
||||
|
||||
optim = Adam(vit.parameters(), lr = LEARNING_RATE)
|
||||
|
||||
# prepare
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
accelerator = Accelerator()
|
||||
|
||||
vit, optim, dataloader = accelerator.prepare(vit, optim, dataloader)
|
||||
|
||||
# experiment
|
||||
|
||||
import wandb
|
||||
|
||||
wandb.init(
|
||||
project = 'vit-decorr',
|
||||
mode = 'disabled' if not TRACK_EXPERIMENT_ONLINE else 'online'
|
||||
)
|
||||
|
||||
wandb.run.name = 'baseline'
|
||||
|
||||
# loop
|
||||
|
||||
for _ in range(EPOCHS):
|
||||
for images, labels in dataloader:
|
||||
|
||||
logits, decorr_aux_loss = vit(images)
|
||||
loss = F.cross_entropy(logits, labels)
|
||||
|
||||
|
||||
total_loss = (
|
||||
loss +
|
||||
decorr_aux_loss * DECORR_LOSS_WEIGHT
|
||||
)
|
||||
|
||||
wandb.log(dict(loss = loss, decorr_loss = decorr_aux_loss))
|
||||
|
||||
accelerator.print(f'loss: {loss.item():.3f} | decorr aux loss: {decorr_aux_loss.item():.3f}')
|
||||
|
||||
accelerator.backward(total_loss)
|
||||
optim.step()
|
||||
optim.zero_grad()
|
||||
161
vit_pytorch/accept_video_wrapper.py
Normal file
161
vit_pytorch/accept_video_wrapper.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from torch import is_tensor, randn
|
||||
from torch.nn import Module, Linear, Parameter
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
# classes
|
||||
|
||||
class AcceptVideoWrapper(Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_net: Module,
|
||||
forward_function = 'forward',
|
||||
add_time_pos_emb = False,
|
||||
dim_emb = None,
|
||||
time_seq_len = None,
|
||||
embed_is_channel_first = False,
|
||||
output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
|
||||
proj_embed_to_dim = None
|
||||
):
|
||||
super().__init__()
|
||||
self.image_net = image_net
|
||||
self.forward_function = forward_function # for openclip, used in TRI-LBM
|
||||
|
||||
self.add_time_pos_emb = add_time_pos_emb
|
||||
self.output_pos_add_pos_emb = output_pos_add_pos_emb
|
||||
|
||||
# maybe project the image embedding
|
||||
|
||||
self.embed_proj = None
|
||||
|
||||
if exists(proj_embed_to_dim):
|
||||
assert exists(dim_emb), '`dim_emb` must be passed in'
|
||||
self.embed_proj = Linear(dim_emb, proj_embed_to_dim)
|
||||
|
||||
# time positional embedding
|
||||
|
||||
if add_time_pos_emb:
|
||||
assert exists(dim_emb) and exists(time_seq_len), '`dim_emb` and `time_seq_len` must be set if adding positional embeddings to the output'
|
||||
self.time_seq_len = time_seq_len
|
||||
|
||||
dim_pos_emb = default(proj_embed_to_dim, dim_emb)
|
||||
|
||||
self.pos_emb = Parameter(randn(time_seq_len, dim_pos_emb) * 1e-2)
|
||||
|
||||
self.embed_is_channel_first = embed_is_channel_first
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video, # (b c t h w)
|
||||
eval_with_no_grad = False,
|
||||
forward_kwargs = dict()
|
||||
):
|
||||
add_time_pos_emb = self.add_time_pos_emb
|
||||
time = video.shape[2]
|
||||
|
||||
# maybe validate time positional embedding
|
||||
|
||||
if add_time_pos_emb:
|
||||
assert time <= self.time_seq_len, f'received video with {time} frames but `time_seq_len` ({self.time_seq_len}) is too low'
|
||||
|
||||
video = rearrange(video, 'b c t h w -> b t c h w')
|
||||
|
||||
video = rearrange(video, 'b t ... -> (b t) ...')
|
||||
|
||||
# forward through image net for outputs
|
||||
|
||||
func = getattr(self.image_net, self.forward_function)
|
||||
|
||||
if eval_with_no_grad:
|
||||
self.image_net.eval()
|
||||
|
||||
context = torch.no_grad if eval_with_no_grad else nullcontext
|
||||
|
||||
with context():
|
||||
outputs = func(video, **forward_kwargs)
|
||||
|
||||
# handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
|
||||
|
||||
outputs, tree_spec = tree_flatten(outputs)
|
||||
|
||||
outputs = tuple(rearrange(t, '(b t) ... -> b t ...', t = time) if is_tensor(t) and t.numel() > 1 else t for t in outputs)
|
||||
|
||||
# maybe project embedding
|
||||
|
||||
if exists(self.embed_proj):
|
||||
outputs = list(outputs)
|
||||
|
||||
embed = outputs[self.output_pos_add_pos_emb]
|
||||
|
||||
outputs[self.output_pos_add_pos_emb] = self.embed_proj(embed)
|
||||
|
||||
# maybe add time positional embedding
|
||||
|
||||
if add_time_pos_emb:
|
||||
|
||||
outputs = list(outputs)
|
||||
embed = outputs[self.output_pos_add_pos_emb]
|
||||
|
||||
pos_emb = rearrange(self.pos_emb, 't d -> 1 t d')
|
||||
|
||||
# handle the network outputting embeddings with spatial dimensions intact - assume embedded dimension is last
|
||||
|
||||
dims_to_unsqueeze = embed.ndim - pos_emb.ndim
|
||||
|
||||
one_dims = ((1,) * dims_to_unsqueeze)
|
||||
|
||||
if self.embed_is_channel_first:
|
||||
pos_emb = pos_emb.reshape(*pos_emb.shape, *one_dims)
|
||||
else:
|
||||
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *one_dims, pos_emb.shape[-1])
|
||||
|
||||
pos_emb = pos_emb[:, :embed.shape[1]]
|
||||
|
||||
embed = embed + pos_emb
|
||||
|
||||
outputs[self.output_pos_add_pos_emb] = embed
|
||||
|
||||
return tree_unflatten(outputs, tree_spec)
|
||||
|
||||
# main
|
||||
|
||||
if __name__ == '__main__':
|
||||
from vit_pytorch import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
videos = torch.randn(1, 3, 7, 256, 256)
|
||||
|
||||
# step up the difficulty and return embeddings for robotics
|
||||
|
||||
from vit_pytorch.extractor import Extractor
|
||||
v = Extractor(v)
|
||||
|
||||
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024, proj_embed_to_dim = 512)
|
||||
|
||||
logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
|
||||
|
||||
assert logits.shape == (1, 7, 1000)
|
||||
assert embeddings.shape == (1, 7, 65, 512)
|
||||
@@ -316,6 +316,9 @@ class CCT(nn.Module):
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
dropout_rate=0.,
|
||||
attention_dropout=0.1,
|
||||
stochastic_depth_rate=0.1,
|
||||
*args, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -340,9 +343,9 @@ class CCT(nn.Module):
|
||||
width=img_width),
|
||||
embedding_dim=embedding_dim,
|
||||
seq_pool=True,
|
||||
dropout_rate=0.,
|
||||
attention_dropout=0.1,
|
||||
stochastic_depth=0.1,
|
||||
dropout_rate=dropout_rate,
|
||||
attention_dropout=attention_dropout,
|
||||
stochastic_depth_rate=stochastic_depth_rate,
|
||||
*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
@@ -167,8 +167,10 @@ class Tokenizer(nn.Module):
|
||||
stride,
|
||||
padding,
|
||||
frame_stride=1,
|
||||
frame_padding=None,
|
||||
frame_pooling_stride=1,
|
||||
frame_pooling_kernel_size=1,
|
||||
frame_pooling_padding=None,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
@@ -188,16 +190,22 @@ class Tokenizer(nn.Module):
|
||||
|
||||
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
|
||||
|
||||
if frame_padding is None:
|
||||
frame_padding = frame_kernel_size // 2
|
||||
|
||||
if frame_pooling_padding is None:
|
||||
frame_pooling_padding = frame_pooling_kernel_size // 2
|
||||
|
||||
self.conv_layers = nn.Sequential(
|
||||
*[nn.Sequential(
|
||||
nn.Conv3d(chan_in, chan_out,
|
||||
kernel_size=(frame_kernel_size, kernel_size, kernel_size),
|
||||
stride=(frame_stride, stride, stride),
|
||||
padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias),
|
||||
padding=(frame_padding, padding, padding), bias=conv_bias),
|
||||
nn.Identity() if not exists(activation) else activation(),
|
||||
nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
|
||||
stride=(frame_pooling_stride, pooling_stride, pooling_stride),
|
||||
padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
|
||||
padding=(frame_pooling_padding, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
|
||||
)
|
||||
for chan_in, chan_out in n_filter_list_pairs
|
||||
])
|
||||
@@ -324,8 +332,10 @@ class CCT(nn.Module):
|
||||
n_conv_layers=1,
|
||||
frame_stride=1,
|
||||
frame_kernel_size=3,
|
||||
frame_padding=None,
|
||||
frame_pooling_kernel_size=1,
|
||||
frame_pooling_stride=1,
|
||||
frame_pooling_padding=None,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
@@ -342,8 +352,10 @@ class CCT(nn.Module):
|
||||
n_output_channels=embedding_dim,
|
||||
frame_stride=frame_stride,
|
||||
frame_kernel_size=frame_kernel_size,
|
||||
frame_padding=frame_padding,
|
||||
frame_pooling_stride=frame_pooling_stride,
|
||||
frame_pooling_kernel_size=frame_pooling_kernel_size,
|
||||
frame_pooling_padding=frame_pooling_padding,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
|
||||
204
vit_pytorch/jumbo_vit.py
Normal file
204
vit_pytorch/jumbo_vit.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# Simpler Fast Vision Transformers with a Jumbo CLS Token
|
||||
# https://arxiv.org/abs/2502.15021
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat, reduce, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb"
|
||||
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = temperature ** -omega
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pos_emb = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
|
||||
return pos_emb.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
def FeedForward(dim, mult = 4.):
|
||||
hidden_dim = int(dim * mult)
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class JumboViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
num_jumbo_cls = 1, # differing from paper, allow for multiple jumbo cls, so one could break it up into 2 jumbo cls tokens with 3x the dim, as an example
|
||||
jumbo_cls_k = 6, # they use a CLS token with this factor times the dimension - 6 was the value they settled on
|
||||
jumbo_ff_mult = 2, # expansion factor of the jumbo cls token feedforward
|
||||
channels = 3,
|
||||
dim_head = 64
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
jumbo_cls_dim = dim * jumbo_cls_k
|
||||
|
||||
self.jumbo_cls_token = nn.Parameter(torch.zeros(num_jumbo_cls, jumbo_cls_dim))
|
||||
|
||||
jumbo_cls_to_tokens = Rearrange('b n (k d) -> b (n k) d', k = jumbo_cls_k)
|
||||
self.jumbo_cls_to_tokens = jumbo_cls_to_tokens
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
|
||||
# attention and feedforwards
|
||||
|
||||
self.jumbo_ff = nn.Sequential(
|
||||
Rearrange('b (n k) d -> b n (k d)', k = jumbo_cls_k),
|
||||
FeedForward(jumbo_cls_dim, int(jumbo_cls_dim * jumbo_ff_mult)), # they use separate parameters for the jumbo feedforward, weight tied for parameter efficient
|
||||
jumbo_cls_to_tokens
|
||||
)
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim),
|
||||
]))
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
|
||||
batch, device = img.shape[0], img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
|
||||
# pos embedding
|
||||
|
||||
pos_emb = self.pos_embedding.to(device, dtype = x.dtype)
|
||||
|
||||
x = x + pos_emb
|
||||
|
||||
# add cls tokens
|
||||
|
||||
cls_tokens = repeat(self.jumbo_cls_token, 'nj d -> b nj d', b = batch)
|
||||
|
||||
jumbo_tokens = self.jumbo_cls_to_tokens(cls_tokens)
|
||||
|
||||
x, cls_packed_shape = pack([jumbo_tokens, x], 'b * d')
|
||||
|
||||
# attention and feedforwards
|
||||
|
||||
for layer, (attn, ff) in enumerate(self.layers, start = 1):
|
||||
is_last = layer == len(self.layers)
|
||||
|
||||
x = attn(x) + x
|
||||
|
||||
# jumbo feedforward
|
||||
|
||||
jumbo_cls_tokens, x = unpack(x, cls_packed_shape, 'b * d')
|
||||
|
||||
x = ff(x) + x
|
||||
jumbo_cls_tokens = self.jumbo_ff(jumbo_cls_tokens) + jumbo_cls_tokens
|
||||
|
||||
if is_last:
|
||||
continue
|
||||
|
||||
x, _ = pack([jumbo_cls_tokens, x], 'b * d')
|
||||
|
||||
pooled = reduce(jumbo_cls_tokens, 'b n d -> b d', 'mean')
|
||||
|
||||
# normalization and project to logits
|
||||
|
||||
embed = self.norm(pooled)
|
||||
|
||||
embed = self.to_latent(embed)
|
||||
logits = self.linear_head(embed)
|
||||
return logits
|
||||
|
||||
# copy pasteable file
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
v = JumboViT(
|
||||
num_classes = 1000,
|
||||
image_size = 64,
|
||||
patch_size = 8,
|
||||
dim = 16,
|
||||
depth = 2,
|
||||
heads = 2,
|
||||
mlp_dim = 32,
|
||||
jumbo_cls_k = 3,
|
||||
jumbo_ff_mult = 2,
|
||||
)
|
||||
|
||||
images = torch.randn(1, 3, 64, 64)
|
||||
|
||||
logits = v(images)
|
||||
assert logits.shape == (1, 1000)
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import List, Union
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -245,7 +247,7 @@ class NaViT(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
|
||||
batched_images: List[Tensor] | List[List[Tensor]], # assume different resolution images already grouped correctly
|
||||
group_images = False,
|
||||
group_max_seq_len = 2048
|
||||
):
|
||||
@@ -264,6 +266,11 @@ class NaViT(nn.Module):
|
||||
max_seq_len = group_max_seq_len
|
||||
)
|
||||
|
||||
# if List[Tensor] is not grouped -> List[List[Tensor]]
|
||||
|
||||
if torch.is_tensor(batched_images[0]):
|
||||
batched_images = [batched_images]
|
||||
|
||||
# process images into variable lengthed sequences with attention mask
|
||||
|
||||
num_images = []
|
||||
|
||||
330
vit_pytorch/na_vit_nested_tensor.py
Normal file
330
vit_pytorch/na_vit_nested_tensor.py
Normal file
@@ -0,0 +1,330 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import packaging.version as pkg_version
|
||||
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.nested import nested_tensor
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
# feedforward
|
||||
|
||||
def FeedForward(dim, hidden_dim, dropout = 0.):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
dim_inner = heads * dim_head
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.to_queries = nn.Linear(dim, dim_inner, bias = False)
|
||||
self.to_keys = nn.Linear(dim, dim_inner, bias = False)
|
||||
self.to_values = nn.Linear(dim, dim_inner, bias = False)
|
||||
|
||||
# in the paper, they employ qk rmsnorm, a way to stabilize attention
|
||||
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
|
||||
|
||||
self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
|
||||
self.dropout = dropout
|
||||
|
||||
self.to_out = nn.Linear(dim_inner, dim, bias = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context: Tensor | None = None
|
||||
):
|
||||
x = self.norm(x)
|
||||
|
||||
# for attention pooling, one query pooling to entire sequence
|
||||
|
||||
context = default(context, x)
|
||||
|
||||
# queries, keys, values
|
||||
|
||||
query = self.to_queries(x)
|
||||
key = self.to_keys(context)
|
||||
value = self.to_values(context)
|
||||
|
||||
# split heads
|
||||
|
||||
def split_heads(t):
|
||||
return t.unflatten(-1, (self.heads, self.dim_head))
|
||||
|
||||
def transpose_head_seq(t):
|
||||
return t.transpose(1, 2)
|
||||
|
||||
query, key, value = map(split_heads, (query, key, value))
|
||||
|
||||
# qk norm for attention stability
|
||||
|
||||
query = self.query_norm(query)
|
||||
key = self.key_norm(key)
|
||||
|
||||
query, key, value = map(transpose_head_seq, (query, key, value))
|
||||
|
||||
# attention
|
||||
|
||||
out = F.scaled_dot_product_attention(
|
||||
query, key, value,
|
||||
dropout_p = self.dropout if self.training else 0.
|
||||
)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = out.transpose(1, 2).flatten(-2)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class NaViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
qk_rmsnorm = True,
|
||||
token_dropout_prob: float | None = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
|
||||
print('nested tensor NaViT was tested on pytorch 2.5')
|
||||
|
||||
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
# what percent of tokens to dropout
|
||||
# if int or float given, then assume constant dropout prob
|
||||
# otherwise accept a callback that in turn calculates dropout prob from height and width
|
||||
|
||||
self.token_dropout_prob = token_dropout_prob
|
||||
|
||||
# calculate patching related stuff
|
||||
|
||||
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
|
||||
patch_dim = channels * (patch_size ** 2)
|
||||
|
||||
self.channels = channels
|
||||
self.patch_size = patch_size
|
||||
self.to_patches = Rearrange('c (h p1) (w p2) -> h w (c p1 p2)', p1 = patch_size, p2 = patch_size)
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
|
||||
self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
|
||||
|
||||
# final attention pooling queries
|
||||
|
||||
self.attn_pool_queries = nn.Parameter(torch.randn(dim))
|
||||
self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)
|
||||
|
||||
# output to logits
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, num_classes, bias = False)
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images: List[Tensor], # different resolution images
|
||||
):
|
||||
batch, device = len(images), self.device
|
||||
arange = partial(torch.arange, device = device)
|
||||
|
||||
assert all([image.ndim == 3 and image.shape[0] == self.channels for image in images]), f'all images must have {self.channels} channels and number of dimensions of 3 (channels, height, width)'
|
||||
|
||||
all_patches = [self.to_patches(image) for image in images]
|
||||
|
||||
# prepare factorized positional embedding height width indices
|
||||
|
||||
positions = []
|
||||
|
||||
for patches in all_patches:
|
||||
patch_height, patch_width = patches.shape[:2]
|
||||
hw_indices = torch.stack(torch.meshgrid((arange(patch_height), arange(patch_width)), indexing = 'ij'), dim = -1)
|
||||
hw_indices = rearrange(hw_indices, 'h w c -> (h w) c')
|
||||
positions.append(hw_indices)
|
||||
|
||||
# need the sizes to compute token dropout + positional embedding
|
||||
|
||||
tokens = [rearrange(patches, 'h w d -> (h w) d') for patches in all_patches]
|
||||
|
||||
# handle token dropout
|
||||
|
||||
seq_lens = torch.tensor([i.shape[0] for i in tokens], device = device)
|
||||
|
||||
if self.training and self.token_dropout_prob > 0:
|
||||
|
||||
keep_seq_lens = ((1. - self.token_dropout_prob) * seq_lens).int().clamp(min = 1)
|
||||
|
||||
kept_tokens = []
|
||||
kept_positions = []
|
||||
|
||||
for one_image_tokens, one_image_positions, seq_len, num_keep in zip(tokens, positions, seq_lens, keep_seq_lens):
|
||||
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
|
||||
|
||||
one_image_kept_tokens = one_image_tokens[keep_indices]
|
||||
one_image_kept_positions = one_image_positions[keep_indices]
|
||||
|
||||
kept_tokens.append(one_image_kept_tokens)
|
||||
kept_positions.append(one_image_kept_positions)
|
||||
|
||||
tokens, positions, seq_lens = kept_tokens, kept_positions, keep_seq_lens
|
||||
|
||||
# add all height and width factorized positions
|
||||
|
||||
height_indices, width_indices = torch.cat(positions).unbind(dim = -1)
|
||||
height_embed, width_embed = self.pos_embed_height[height_indices], self.pos_embed_width[width_indices]
|
||||
|
||||
pos_embed = height_embed + width_embed
|
||||
|
||||
# use nested tensor for transformers and save on padding computation
|
||||
|
||||
tokens = torch.cat(tokens)
|
||||
|
||||
# linear projection to patch embeddings
|
||||
|
||||
tokens = self.to_patch_embedding(tokens)
|
||||
|
||||
# absolute positions
|
||||
|
||||
tokens = tokens + pos_embed
|
||||
|
||||
tokens = nested_tensor(tokens.split(seq_lens.tolist()), layout = torch.jagged, device = device)
|
||||
|
||||
# embedding dropout
|
||||
|
||||
tokens = self.dropout(tokens)
|
||||
|
||||
# transformer
|
||||
|
||||
tokens = self.transformer(tokens)
|
||||
|
||||
# attention pooling
|
||||
# will use a jagged tensor for queries, as SDPA requires all inputs to be jagged, or not
|
||||
|
||||
attn_pool_queries = [rearrange(self.attn_pool_queries, '... -> 1 ...')] * batch
|
||||
|
||||
attn_pool_queries = nested_tensor(attn_pool_queries, layout = torch.jagged)
|
||||
|
||||
pooled = self.attn_pool(attn_pool_queries, tokens)
|
||||
|
||||
# back to unjagged
|
||||
|
||||
logits = torch.stack(pooled.unbind())
|
||||
|
||||
logits = rearrange(logits, 'b 1 d -> b d')
|
||||
|
||||
logits = self.to_latent(logits)
|
||||
|
||||
return self.mlp_head(logits)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
v = NaViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
token_dropout_prob = 0.1
|
||||
)
|
||||
|
||||
# 5 images of different resolutions - List[Tensor]
|
||||
|
||||
images = [
|
||||
torch.randn(3, 256, 256), torch.randn(3, 128, 128),
|
||||
torch.randn(3, 128, 256), torch.randn(3, 256, 128),
|
||||
torch.randn(3, 64, 256)
|
||||
]
|
||||
|
||||
assert v(images).shape == (5, 1000)
|
||||
|
||||
v(images).sum().backward()
|
||||
356
vit_pytorch/na_vit_nested_tensor_3d.py
Normal file
356
vit_pytorch/na_vit_nested_tensor_3d.py
Normal file
@@ -0,0 +1,356 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import packaging.version as pkg_version
|
||||
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.nested import nested_tensor
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
# feedforward
|
||||
|
||||
def FeedForward(dim, hidden_dim, dropout = 0.):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
dim_inner = heads * dim_head
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.to_queries = nn.Linear(dim, dim_inner, bias = False)
|
||||
self.to_keys = nn.Linear(dim, dim_inner, bias = False)
|
||||
self.to_values = nn.Linear(dim, dim_inner, bias = False)
|
||||
|
||||
# in the paper, they employ qk rmsnorm, a way to stabilize attention
|
||||
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
|
||||
|
||||
self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
|
||||
self.dropout = dropout
|
||||
|
||||
self.to_out = nn.Linear(dim_inner, dim, bias = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context: Tensor | None = None
|
||||
):
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
# for attention pooling, one query pooling to entire sequence
|
||||
|
||||
context = default(context, x)
|
||||
|
||||
# queries, keys, values
|
||||
|
||||
query = self.to_queries(x)
|
||||
key = self.to_keys(context)
|
||||
value = self.to_values(context)
|
||||
|
||||
# split heads
|
||||
|
||||
def split_heads(t):
|
||||
return t.unflatten(-1, (self.heads, self.dim_head))
|
||||
|
||||
def transpose_head_seq(t):
|
||||
return t.transpose(1, 2)
|
||||
|
||||
query, key, value = map(split_heads, (query, key, value))
|
||||
|
||||
# qk norm for attention stability
|
||||
|
||||
query = self.query_norm(query)
|
||||
key = self.key_norm(key)
|
||||
|
||||
query, key, value = map(transpose_head_seq, (query, key, value))
|
||||
|
||||
# attention
|
||||
|
||||
out = F.scaled_dot_product_attention(
|
||||
query, key, value,
|
||||
dropout_p = self.dropout if self.training else 0.
|
||||
)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = out.transpose(1, 2).flatten(-2)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class NaViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
max_frames,
|
||||
patch_size,
|
||||
frame_patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
num_registers = 4,
|
||||
qk_rmsnorm = True,
|
||||
token_dropout_prob: float | None = None
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
|
||||
print('nested tensor NaViT was tested on pytorch 2.5')
|
||||
|
||||
# what percent of tokens to dropout
|
||||
# if int or float given, then assume constant dropout prob
|
||||
# otherwise accept a callback that in turn calculates dropout prob from height and width
|
||||
|
||||
self.token_dropout_prob = token_dropout_prob
|
||||
|
||||
# calculate patching related stuff
|
||||
|
||||
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
|
||||
assert divisible_by(max_frames, frame_patch_size)
|
||||
|
||||
patch_frame_dim, patch_height_dim, patch_width_dim = (max_frames // frame_patch_size), (image_height // patch_size), (image_width // patch_size)
|
||||
|
||||
patch_dim = channels * (patch_size ** 2) * frame_patch_size
|
||||
|
||||
self.channels = channels
|
||||
self.patch_size = patch_size
|
||||
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c pf p1 p2)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embed_frame = nn.Parameter(torch.zeros(patch_frame_dim, dim))
|
||||
self.pos_embed_height = nn.Parameter(torch.zeros(patch_height_dim, dim))
|
||||
self.pos_embed_width = nn.Parameter(torch.zeros(patch_width_dim, dim))
|
||||
|
||||
# register tokens
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.zeros(num_registers, dim))
|
||||
|
||||
nn.init.normal_(self.pos_embed_frame, std = 0.02)
|
||||
nn.init.normal_(self.pos_embed_height, std = 0.02)
|
||||
nn.init.normal_(self.pos_embed_width, std = 0.02)
|
||||
nn.init.normal_(self.register_tokens, std = 0.02)
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
|
||||
|
||||
# final attention pooling queries
|
||||
|
||||
self.attn_pool_queries = nn.Parameter(torch.randn(dim))
|
||||
self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)
|
||||
|
||||
# output to logits
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, num_classes, bias = False)
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
volumes: List[Tensor], # different resolution images / CT scans
|
||||
):
|
||||
batch, device = len(volumes), self.device
|
||||
arange = partial(torch.arange, device = device)
|
||||
|
||||
assert all([volume.ndim == 4 and volume.shape[0] == self.channels for volume in volumes]), f'all volumes must have {self.channels} channels and number of dimensions of {self.channels} (channels, frame, height, width)'
|
||||
|
||||
all_patches = [self.to_patches(volume) for volume in volumes]
|
||||
|
||||
# prepare factorized positional embedding height width indices
|
||||
|
||||
positions = []
|
||||
|
||||
for patches in all_patches:
|
||||
patch_frame, patch_height, patch_width = patches.shape[:3]
|
||||
fhw_indices = torch.stack(torch.meshgrid((arange(patch_frame), arange(patch_height), arange(patch_width)), indexing = 'ij'), dim = -1)
|
||||
fhw_indices = rearrange(fhw_indices, 'f h w c -> (f h w) c')
|
||||
|
||||
positions.append(fhw_indices)
|
||||
|
||||
# need the sizes to compute token dropout + positional embedding
|
||||
|
||||
tokens = [rearrange(patches, 'f h w d -> (f h w) d') for patches in all_patches]
|
||||
|
||||
# handle token dropout
|
||||
|
||||
seq_lens = torch.tensor([i.shape[0] for i in tokens], device = device)
|
||||
|
||||
if self.training and self.token_dropout_prob > 0:
|
||||
|
||||
keep_seq_lens = ((1. - self.token_dropout_prob) * seq_lens).int().clamp(min = 1)
|
||||
|
||||
kept_tokens = []
|
||||
kept_positions = []
|
||||
|
||||
for one_image_tokens, one_image_positions, seq_len, num_keep in zip(tokens, positions, seq_lens, keep_seq_lens):
|
||||
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
|
||||
|
||||
one_image_kept_tokens = one_image_tokens[keep_indices]
|
||||
one_image_kept_positions = one_image_positions[keep_indices]
|
||||
|
||||
kept_tokens.append(one_image_kept_tokens)
|
||||
kept_positions.append(one_image_kept_positions)
|
||||
|
||||
tokens, positions, seq_lens = kept_tokens, kept_positions, keep_seq_lens
|
||||
|
||||
# add all height and width factorized positions
|
||||
|
||||
|
||||
frame_indices, height_indices, width_indices = torch.cat(positions).unbind(dim = -1)
|
||||
frame_embed, height_embed, width_embed = self.pos_embed_frame[frame_indices], self.pos_embed_height[height_indices], self.pos_embed_width[width_indices]
|
||||
|
||||
pos_embed = frame_embed + height_embed + width_embed
|
||||
|
||||
tokens = torch.cat(tokens)
|
||||
|
||||
# linear projection to patch embeddings
|
||||
|
||||
tokens = self.to_patch_embedding(tokens)
|
||||
|
||||
# absolute positions
|
||||
|
||||
tokens = tokens + pos_embed
|
||||
|
||||
# add register tokens
|
||||
|
||||
tokens = tokens.split(seq_lens.tolist())
|
||||
|
||||
tokens = [torch.cat((self.register_tokens, one_tokens)) for one_tokens in tokens]
|
||||
|
||||
# use nested tensor for transformers and save on padding computation
|
||||
|
||||
tokens = nested_tensor(tokens, layout = torch.jagged, device = device)
|
||||
|
||||
# embedding dropout
|
||||
|
||||
tokens = self.dropout(tokens)
|
||||
|
||||
# transformer
|
||||
|
||||
tokens = self.transformer(tokens)
|
||||
|
||||
# attention pooling
|
||||
# will use a jagged tensor for queries, as SDPA requires all inputs to be jagged, or not
|
||||
|
||||
attn_pool_queries = [rearrange(self.attn_pool_queries, '... -> 1 ...')] * batch
|
||||
|
||||
attn_pool_queries = nested_tensor(attn_pool_queries, layout = torch.jagged)
|
||||
|
||||
pooled = self.attn_pool(attn_pool_queries, tokens)
|
||||
|
||||
# back to unjagged
|
||||
|
||||
logits = torch.stack(pooled.unbind())
|
||||
|
||||
logits = rearrange(logits, 'b 1 d -> b d')
|
||||
|
||||
logits = self.to_latent(logits)
|
||||
|
||||
return self.mlp_head(logits)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# works for torch 2.5
|
||||
|
||||
v = NaViT(
|
||||
image_size = 256,
|
||||
max_frames = 8,
|
||||
patch_size = 32,
|
||||
frame_patch_size = 2,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
token_dropout_prob = 0.1
|
||||
)
|
||||
|
||||
# 5 volumetric data (videos or CT scans) of different resolutions - List[Tensor]
|
||||
|
||||
volumes = [
|
||||
torch.randn(3, 2, 256, 256), torch.randn(3, 8, 128, 128),
|
||||
torch.randn(3, 4, 128, 256), torch.randn(3, 2, 256, 128),
|
||||
torch.randn(3, 4, 64, 256)
|
||||
]
|
||||
|
||||
assert v(volumes).shape == (5, 1000)
|
||||
|
||||
v(volumes).sum().backward()
|
||||
264
vit_pytorch/normalized_vit.py
Normal file
264
vit_pytorch/normalized_vit.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.utils.parametrize as parametrize
|
||||
|
||||
from einops import rearrange, reduce
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# functions
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
def l2norm(t, dim = -1):
|
||||
return F.normalize(t, dim = dim, p = 2)
|
||||
|
||||
# for use with parametrize
|
||||
|
||||
class L2Norm(Module):
|
||||
def __init__(self, dim = -1):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, t):
|
||||
return l2norm(t, dim = self.dim)
|
||||
|
||||
class NormLinear(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_out,
|
||||
norm_dim_in = True
|
||||
):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(dim, dim_out, bias = False)
|
||||
|
||||
parametrize.register_parametrization(
|
||||
self.linear,
|
||||
'weight',
|
||||
L2Norm(dim = -1 if norm_dim_in else 0)
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.linear.weight
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
# attention and feedforward
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
dim_inner = dim_head * heads
|
||||
self.to_q = NormLinear(dim, dim_inner)
|
||||
self.to_k = NormLinear(dim, dim_inner)
|
||||
self.to_v = NormLinear(dim, dim_inner)
|
||||
|
||||
self.dropout = dropout
|
||||
|
||||
self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
|
||||
self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
|
||||
|
||||
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
|
||||
self.merge_heads = Rearrange('b h n d -> b n (h d)')
|
||||
|
||||
self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x
|
||||
):
|
||||
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
|
||||
|
||||
q, k, v = map(self.split_heads, (q, k, v))
|
||||
|
||||
# query key rmsnorm
|
||||
|
||||
q, k = map(l2norm, (q, k))
|
||||
|
||||
q = q * self.q_scale
|
||||
k = k * self.k_scale
|
||||
|
||||
# scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
|
||||
|
||||
out = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p = self.dropout if self.training else 0.,
|
||||
scale = 1.
|
||||
)
|
||||
|
||||
out = self.merge_heads(out)
|
||||
return self.to_out(out)
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
dim_inner,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
dim_inner = int(dim_inner * 2 / 3)
|
||||
|
||||
self.dim = dim
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_hidden = NormLinear(dim, dim_inner)
|
||||
self.to_gate = NormLinear(dim, dim_inner)
|
||||
|
||||
self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
|
||||
self.gate_scale = nn.Parameter(torch.ones(dim_inner))
|
||||
|
||||
self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
|
||||
|
||||
def forward(self, x):
|
||||
hidden, gate = self.to_hidden(x), self.to_gate(x)
|
||||
|
||||
hidden = hidden * self.hidden_scale
|
||||
gate = gate * self.gate_scale * (self.dim ** 0.5)
|
||||
|
||||
hidden = F.silu(gate) * hidden
|
||||
|
||||
hidden = self.dropout(hidden)
|
||||
return self.to_out(hidden)
|
||||
|
||||
# classes
|
||||
|
||||
class nViT(Module):
|
||||
""" https://arxiv.org/abs/2410.01131 """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
dropout = 0.,
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
residual_lerp_scale_init = None
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
# calculate patching related stuff
|
||||
|
||||
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
|
||||
patch_dim = channels * (patch_size ** 2)
|
||||
num_patches = patch_height_dim * patch_width_dim
|
||||
|
||||
self.channels = channels
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
|
||||
NormLinear(patch_dim, dim, norm_dim_in = False),
|
||||
)
|
||||
|
||||
self.abs_pos_emb = NormLinear(dim, num_patches)
|
||||
|
||||
residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
|
||||
|
||||
# layers
|
||||
|
||||
self.dim = dim
|
||||
self.scale = dim ** 0.5
|
||||
|
||||
self.layers = ModuleList([])
|
||||
self.residual_lerp_scales = nn.ParameterList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, dim_head = dim_head, heads = heads, dropout = dropout),
|
||||
FeedForward(dim, dim_inner = mlp_dim, dropout = dropout),
|
||||
]))
|
||||
|
||||
self.residual_lerp_scales.append(nn.ParameterList([
|
||||
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
|
||||
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
|
||||
]))
|
||||
|
||||
self.logit_scale = nn.Parameter(torch.ones(num_classes))
|
||||
|
||||
self.to_pred = NormLinear(dim, num_classes)
|
||||
|
||||
@torch.no_grad()
|
||||
def norm_weights_(self):
|
||||
for module in self.modules():
|
||||
if not isinstance(module, NormLinear):
|
||||
continue
|
||||
|
||||
normed = module.weight
|
||||
original = module.linear.parametrizations.weight.original
|
||||
|
||||
original.copy_(normed)
|
||||
|
||||
def forward(self, images):
|
||||
device = images.device
|
||||
|
||||
tokens = self.to_patch_embedding(images)
|
||||
|
||||
seq_len = tokens.shape[-2]
|
||||
pos_emb = self.abs_pos_emb.weight[torch.arange(seq_len, device = device)]
|
||||
|
||||
tokens = l2norm(tokens + pos_emb)
|
||||
|
||||
for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):
|
||||
|
||||
attn_out = l2norm(attn(tokens))
|
||||
tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale))
|
||||
|
||||
ff_out = l2norm(ff(tokens))
|
||||
tokens = l2norm(tokens.lerp(ff_out, ff_alpha * self.scale))
|
||||
|
||||
pooled = reduce(tokens, 'b n d -> b d', 'mean')
|
||||
|
||||
logits = self.to_pred(pooled)
|
||||
logits = logits * self.logit_scale * self.scale
|
||||
|
||||
return logits
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
v = nViT(
|
||||
image_size = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
)
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
logits = v(img) # (4, 1000)
|
||||
assert logits.shape == (4, 1000)
|
||||
@@ -20,6 +20,18 @@ def divisible_by(val, d):
|
||||
|
||||
# helper classes
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
@@ -212,10 +224,10 @@ class RegionViT(nn.Module):
|
||||
if tokenize_local_3_conv:
|
||||
self.local_encoder = nn.Sequential(
|
||||
nn.Conv2d(3, init_dim, 3, 2, 1),
|
||||
nn.LayerNorm(init_dim),
|
||||
ChanLayerNorm(init_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(init_dim, init_dim, 3, 2, 1),
|
||||
nn.LayerNorm(init_dim),
|
||||
ChanLayerNorm(init_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(init_dim, init_dim, 3, 1, 1)
|
||||
)
|
||||
|
||||
@@ -3,14 +3,14 @@ from math import sqrt, pi, log
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.amp import autocast
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
@autocast(enabled = False)
|
||||
@autocast('cuda', enabled = False)
|
||||
def rotate_every_two(x):
|
||||
x = rearrange(x, '... (d j) -> ... d j', j = 2)
|
||||
x1, x2 = x.unbind(dim = -1)
|
||||
@@ -24,7 +24,7 @@ class AxialRotaryEmbedding(nn.Module):
|
||||
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
|
||||
self.register_buffer('scales', scales)
|
||||
|
||||
@autocast(enabled = False)
|
||||
@autocast('cuda', enabled = False)
|
||||
def forward(self, x):
|
||||
device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))
|
||||
|
||||
|
||||
@@ -146,7 +146,7 @@ class SimpleViT(Module):
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
@@ -103,7 +103,7 @@ class SimpleViT(nn.Module):
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
233
vit_pytorch/simple_vit_with_hyper_connections.py
Normal file
233
vit_pytorch/simple_vit_with_hyper_connections.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
ViT + Hyper-Connections + Register Tokens
|
||||
https://arxiv.org/abs/2409.19606
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn, tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat, reduce, einsum, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# b - batch, h - heads, n - sequence, e - expansion rate / residual streams, d - feature dimension
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# hyper connections
|
||||
|
||||
class HyperConnection(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_residual_streams,
|
||||
layer_index
|
||||
):
|
||||
""" Appendix J - Algorithm 2, Dynamic only """
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
self.num_residual_streams = num_residual_streams
|
||||
self.layer_index = layer_index
|
||||
|
||||
self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
|
||||
|
||||
init_alpha0 = torch.zeros((num_residual_streams, 1))
|
||||
init_alpha0[layer_index % num_residual_streams, 0] = 1.
|
||||
|
||||
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
|
||||
|
||||
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
|
||||
self.dynamic_alpha_scale = nn.Parameter(tensor(1e-2))
|
||||
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
|
||||
self.dynamic_beta_scale = nn.Parameter(tensor(1e-2))
|
||||
|
||||
def width_connection(self, residuals):
|
||||
normed = self.norm(residuals)
|
||||
|
||||
wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
|
||||
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
|
||||
alpha = dynamic_alpha + self.static_alpha
|
||||
|
||||
dc_weight = (normed @ self.dynamic_beta_fn).tanh()
|
||||
dynamic_beta = dc_weight * self.dynamic_beta_scale
|
||||
beta = dynamic_beta + self.static_beta
|
||||
|
||||
# width connection
|
||||
mix_h = einsum(alpha, residuals, '... e1 e2, ... e1 d -> ... e2 d')
|
||||
|
||||
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
|
||||
|
||||
return branch_input, residuals, beta
|
||||
|
||||
def depth_connection(
|
||||
self,
|
||||
branch_output,
|
||||
residuals,
|
||||
beta
|
||||
):
|
||||
return einsum(branch_output, beta, "b n d, b n e -> b n e d") + residuals
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_residual_streams):
|
||||
super().__init__()
|
||||
|
||||
self.num_residual_streams = num_residual_streams
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for layer_index in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
HyperConnection(dim, num_residual_streams, layer_index),
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
HyperConnection(dim, num_residual_streams, layer_index),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = repeat(x, 'b n d -> b n e d', e = self.num_residual_streams)
|
||||
|
||||
for attn_hyper_conn, attn, ff_hyper_conn, ff in self.layers:
|
||||
|
||||
x, attn_res, beta = attn_hyper_conn.width_connection(x)
|
||||
|
||||
x = attn(x)
|
||||
|
||||
x = attn_hyper_conn.depth_connection(x, attn_res, beta)
|
||||
|
||||
x, ff_res, beta = ff_hyper_conn.width_connection(x)
|
||||
|
||||
x = ff(x)
|
||||
|
||||
x = ff_hyper_conn.depth_connection(x, ff_res, beta)
|
||||
|
||||
x = reduce(x, 'b n e d -> b n d', 'sum')
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_residual_streams, num_register_tokens = 4, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams)
|
||||
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
batch, device = img.shape[0], img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
x += self.pos_embedding.to(x)
|
||||
|
||||
r = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
|
||||
x, ps = pack([x, r], 'b * d')
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x, _ = unpack(x, ps, 'b * d')
|
||||
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
|
||||
# main
|
||||
|
||||
if __name__ == '__main__':
|
||||
vit = SimpleViT(
|
||||
num_classes = 1000,
|
||||
image_size = 256,
|
||||
patch_size = 8,
|
||||
dim = 1024,
|
||||
depth = 12,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
num_residual_streams = 8
|
||||
)
|
||||
|
||||
images = torch.randn(3, 3, 256, 256)
|
||||
|
||||
logits = vit(images)
|
||||
159
vit_pytorch/simple_vit_with_value_residual.py
Normal file
159
vit_pytorch/simple_vit_with_value_residual.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
def FeedForward(dim, hidden_dim):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
self.to_residual_mix = nn.Sequential(
|
||||
nn.Linear(dim, heads),
|
||||
nn.Sigmoid(),
|
||||
Rearrange('b n h -> b h n 1')
|
||||
) if learned_value_residual_mix else (lambda _: 0.5)
|
||||
|
||||
def forward(self, x, value_residual = None):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
if exists(value_residual):
|
||||
mix = self.to_residual_mix(x)
|
||||
v = v * mix + value_residual * (1. - mix)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
return self.to_out(out), v
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
for i in range(depth):
|
||||
is_first = i == 0
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
value_residual = None
|
||||
|
||||
for attn, ff in self.layers:
|
||||
|
||||
attn_out, values = attn(x, value_residual = value_residual)
|
||||
value_residual = default(value_residual, values)
|
||||
|
||||
x = attn_out + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
device = img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
x += self.pos_embedding.to(device, dtype=x.dtype)
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
v = SimpleViT(
|
||||
num_classes = 1000,
|
||||
image_size = 256,
|
||||
patch_size = 8,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
)
|
||||
|
||||
images = torch.randn(2, 3, 256, 256)
|
||||
|
||||
logits = v(images)
|
||||
777
vit_pytorch/vaat.py
Normal file
777
vit_pytorch/vaat.py
Normal file
@@ -0,0 +1,777 @@
|
||||
# vision-audio-action transformer - vaat
|
||||
|
||||
from __future__ import annotations
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, cat, stack, arange, tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from torchaudio.transforms import Spectrogram
|
||||
|
||||
from einops import rearrange, repeat, reduce, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# 2d sinusoidal positional embedding
|
||||
# simple vit paper shows it is good enough compared to learned
|
||||
|
||||
def posemb_sincos_2d(
|
||||
patches,
|
||||
temperature = 10000,
|
||||
dtype = torch.float32
|
||||
):
|
||||
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
||||
|
||||
y, x = torch.meshgrid(arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
|
||||
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
|
||||
|
||||
omega = arange(dim // 4, device = device) / (dim // 4 - 1)
|
||||
omega = temperature ** -omega
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
|
||||
pe = cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
pe = pe.type(dtype)
|
||||
|
||||
return rearrange(pe, '(h w) d -> h w d', h = h, w = w)
|
||||
|
||||
# classes
|
||||
|
||||
class FiLM(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
):
|
||||
super().__init__()
|
||||
proj = nn.Linear(dim, dim * 2)
|
||||
|
||||
self.to_gamma_beta = nn.Sequential(
|
||||
proj,
|
||||
Rearrange('b (two d) -> two b 1 d', two = 2)
|
||||
)
|
||||
|
||||
nn.init.zeros_(proj.weight)
|
||||
nn.init.zeros_(proj.bias)
|
||||
|
||||
def forward(self, tokens, cond):
|
||||
gamma, beta = self.to_gamma_beta(cond)
|
||||
|
||||
return tokens * gamma + beta
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
hidden_dim,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
dim_context = None,
|
||||
cross_attend = False
|
||||
):
|
||||
super().__init__()
|
||||
dim_context = default(dim_context, dim)
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.cross_attend = cross_attend
|
||||
self.context_norm = nn.LayerNorm(dim_context) if cross_attend else None
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x, context = None):
|
||||
|
||||
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross attending, or vice versa'
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
# handle norming of context for cross attention
|
||||
|
||||
kv_input = x
|
||||
|
||||
if self.cross_attend:
|
||||
context = self.context_norm(context)
|
||||
kv_input = context
|
||||
|
||||
# project for queries, keys, values
|
||||
|
||||
qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
dim_head,
|
||||
mlp_dim,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_hiddens = False
|
||||
):
|
||||
|
||||
hiddens = []
|
||||
|
||||
for attn, ff in self.layers:
|
||||
hiddens.append(x)
|
||||
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
if not return_hiddens:
|
||||
return x
|
||||
|
||||
return x, hiddens
|
||||
|
||||
class AST(Module):
|
||||
# audio spectrogram transformer https://arxiv.org/abs/2104.01778
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
mlp_dim,
|
||||
num_classes = None,
|
||||
patch_size = 16,
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
accept_spec = False,
|
||||
accept_spec_time_first = True,
|
||||
spec_n_fft = 128,
|
||||
spec_power = 2,
|
||||
spec_win_length = 24,
|
||||
spec_hop_length = None,
|
||||
spec_pad = 0,
|
||||
spec_center = True,
|
||||
spec_pad_mode = 'reflect',
|
||||
num_register_tokens = 4
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
patch_input_dim = patch_height * patch_width
|
||||
|
||||
self.patch_size = (patch_height, patch_width)
|
||||
|
||||
self.to_patch_tokens = nn.Sequential(
|
||||
Rearrange('b (h p1) (w p2) -> b h w (p1 p2)', p1 = self.patch_size[0], p2 = self.patch_size[1]),
|
||||
nn.LayerNorm(patch_input_dim),
|
||||
nn.Linear(patch_input_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.accept_spec = accept_spec
|
||||
self.accept_spec_time_first = accept_spec_time_first
|
||||
|
||||
self.spec = Spectrogram(
|
||||
n_fft = spec_n_fft,
|
||||
power = spec_power,
|
||||
win_length = spec_win_length,
|
||||
hop_length = spec_hop_length,
|
||||
pad = spec_pad,
|
||||
center = spec_center,
|
||||
pad_mode = spec_pad_mode
|
||||
)
|
||||
|
||||
self.transformer = Transformer(
|
||||
dim = dim,
|
||||
depth = depth,
|
||||
dim_head = dim_head,
|
||||
heads = heads,
|
||||
mlp_dim = mlp_dim,
|
||||
dropout = dropout,
|
||||
)
|
||||
|
||||
self.final_norm = nn.LayerNorm(dim)
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
raw_audio_or_spec, # (b t) | (b f t)
|
||||
return_hiddens = False
|
||||
):
|
||||
batch, device = raw_audio_or_spec.shape[0], raw_audio_or_spec.device
|
||||
|
||||
assert (self.accept_spec and raw_audio_or_spec.ndim == 3) or (not self.accept_spec and raw_audio_or_spec.ndim == 2)
|
||||
|
||||
if self.accept_spec:
|
||||
spec = rearrange(raw_audio_or_spec, 'b t f -> b f t')
|
||||
else:
|
||||
spec = self.spec(raw_audio_or_spec)
|
||||
|
||||
# automatically crop if audio does not yield a 2d spectrogram that is divisible by patch sizes
|
||||
|
||||
height, width = spec.shape[-2:]
|
||||
patch_height, patch_width = self.patch_size
|
||||
|
||||
rounded_height = height // patch_height * patch_height
|
||||
rounded_width = width // patch_width * patch_width
|
||||
|
||||
spec = spec[..., :rounded_height, :rounded_width]
|
||||
|
||||
# to patches
|
||||
|
||||
tokens = self.to_patch_tokens(spec)
|
||||
|
||||
# get number of patches along height and width
|
||||
|
||||
_, num_patch_height, num_patch_width, _ = tokens.shape
|
||||
|
||||
# 2d sinusoidal positional embedding
|
||||
|
||||
tokens = tokens + posemb_sincos_2d(tokens)
|
||||
|
||||
tokens = rearrange(tokens, 'b ... c -> b (...) c')
|
||||
|
||||
# register tokens
|
||||
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
|
||||
tokens, packed_shape = pack((register_tokens, tokens), 'b * d')
|
||||
|
||||
# attention
|
||||
|
||||
attended, hiddens = self.transformer(tokens, return_hiddens = True)
|
||||
|
||||
# final global average and norm (most recent papers show this is superior to CLS token)
|
||||
|
||||
normed = self.final_norm(attended)
|
||||
|
||||
if return_hiddens:
|
||||
return normed, stack(hiddens)
|
||||
|
||||
register_tokens, normed = unpack(normed, packed_shape, 'b * d')
|
||||
|
||||
pooled = reduce(normed, 'b n d -> b d', 'mean')
|
||||
|
||||
maybe_logits = self.mlp_head(pooled)
|
||||
|
||||
return maybe_logits
|
||||
|
||||
class ViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
pool = 'cls',
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
num_register_tokens = 0
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
|
||||
|
||||
def forward(self, img, return_hiddens = False):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
x += self.pos_embedding[:n]
|
||||
|
||||
cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = b)
|
||||
|
||||
x, packed_shape = pack((register_tokens, cls_tokens, x), 'b * d')
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
x, hiddens = self.transformer(x, return_hiddens = True)
|
||||
|
||||
# return the representation trajectory
|
||||
|
||||
if return_hiddens:
|
||||
return x, stack(hiddens)
|
||||
|
||||
register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
|
||||
|
||||
x = self.to_latent(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||
|
||||
# proposed VAT
|
||||
|
||||
# https://openreview.net/forum?id=TalHOvvLZu
|
||||
# simple way to get SOTA on Libero dataset (beating fine-tuned pi-zero)
|
||||
|
||||
class VAAT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
vit: ViT | dict,
|
||||
ast: AST | dict,
|
||||
*,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
dim_head,
|
||||
dim_action,
|
||||
mlp_dim,
|
||||
num_image_views = None,
|
||||
num_audio_views = None,
|
||||
num_tasks = None,
|
||||
dim_extra_token = None,
|
||||
num_register_tokens = 4,
|
||||
action_chunk_len = 7,
|
||||
time_seq_len = 1,
|
||||
dropout = 0.,
|
||||
add_self_attn = True, # in the paper, they didn't have any ways for the action token to exchange information with the extra token, so we'll just add it as an option
|
||||
self_attn_heads = 4,
|
||||
self_attn_dim_head = 32,
|
||||
ast_layer_indices: tuple[int, ...] | None = None,
|
||||
vit_layer_indices: tuple[int, ...] | None = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# vit
|
||||
|
||||
if isinstance(vit, dict):
|
||||
vit = ViT(**vit)
|
||||
|
||||
self.vit = vit
|
||||
|
||||
vit_dim = vit.dim
|
||||
|
||||
assert vit.depth == depth or exists(vit_layer_indices), f'if the VAAT depth is not equal to the ViT depth, you must pass in the indices from the ViT to be layered to the VAAT in order from bottom to top'
|
||||
|
||||
vit_layer_indices = default(vit_layer_indices, tuple(range(depth)))
|
||||
|
||||
assert len(vit_layer_indices) == depth, f'number of vit layer indices {len(vit_layer_indices)} does not much the VAT depth {depth}'
|
||||
|
||||
self.register_buffer('vit_layer_indices', tensor(vit_layer_indices), persistent = False)
|
||||
|
||||
# ast
|
||||
|
||||
if isinstance(ast, dict):
|
||||
ast = AST(**ast)
|
||||
|
||||
self.ast = ast
|
||||
|
||||
ast_dim = ast.dim
|
||||
|
||||
self.ast_accept_spec = ast.accept_spec
|
||||
|
||||
assert ast.depth == depth or exists(ast_layer_indices), f'if the VAAT depth is not equal to the AST depth, you must pass in the indices from the AST to be layered to the VAAT in order from bottom to top'
|
||||
|
||||
ast_layer_indices = default(ast_layer_indices, tuple(range(depth)))
|
||||
|
||||
assert len(ast_layer_indices) == depth, f'number of ast layer indices {len(ast_layer_indices)} does not much the VAAT depth {depth}'
|
||||
|
||||
self.register_buffer('ast_layer_indices', tensor(vit_layer_indices), persistent = False)
|
||||
|
||||
# handle maybe multiple frames
|
||||
|
||||
is_video = time_seq_len > 1
|
||||
|
||||
self.is_video = is_video
|
||||
self.time_seq_len = time_seq_len
|
||||
self.time_pos_emb = nn.Parameter(torch.randn(time_seq_len, vit_dim) * 1e-2) if is_video else None
|
||||
|
||||
# maybe view embeddings
|
||||
|
||||
self.image_view_emb = nn.Parameter(torch.randn(num_image_views, vit_dim) * 1e-2) if exists(num_image_views) and num_image_views > 1 else None
|
||||
|
||||
self.audio_view_emb = nn.Parameter(torch.randn(num_audio_views, ast_dim) * 1e-2) if exists(num_audio_views) and num_audio_views > 1 else None
|
||||
|
||||
# handle maybe task conditioning
|
||||
|
||||
self.has_tasks = exists(num_tasks)
|
||||
|
||||
if self.has_tasks:
|
||||
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
|
||||
|
||||
# register tokens from Darcet et al.
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
|
||||
|
||||
# to action tokens
|
||||
|
||||
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
|
||||
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
maybe_film = FiLM(dim = dim) if self.has_tasks else None
|
||||
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
|
||||
|
||||
self.layers.append(ModuleList([
|
||||
maybe_film,
|
||||
maybe_self_attn,
|
||||
Attention(dim = dim, dim_context = vit_dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
|
||||
Attention(dim = dim, dim_context = ast_dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
|
||||
FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
self.final_norm = nn.LayerNorm(dim)
|
||||
self.to_pred_action = nn.Linear(dim, dim_action, bias = False)
|
||||
|
||||
# handle the extra token
|
||||
|
||||
self.accept_extra_token = exists(dim_extra_token)
|
||||
|
||||
if exists(dim_extra_token):
|
||||
self.to_extra_token = nn.Linear(dim_extra_token, dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
|
||||
audio_or_spec, # (b v? t) | (b v?f t) - batch, audio len | batch, spec freq, time
|
||||
*,
|
||||
extra = None, # (b d) - batch, dim extra
|
||||
tasks = None, # (b)
|
||||
actions = None, # (b k d) - batch, action chunk length, action dimension
|
||||
return_hiddens = False,
|
||||
freeze_vit = False,
|
||||
freeze_ast = False
|
||||
):
|
||||
batch = video_or_image.shape[0]
|
||||
return_loss = exists(actions)
|
||||
|
||||
# handle some various input dimensions
|
||||
|
||||
if video_or_image.ndim == 4:
|
||||
video_or_image = rearrange(video_or_image, 'b 1 c h w')
|
||||
|
||||
assert (
|
||||
(video_or_image.ndim == 5 and not self.is_video) or
|
||||
(video_or_image.ndim == 6 and self.is_video)
|
||||
)
|
||||
|
||||
if video_or_image.ndim == 5:
|
||||
video_or_image = rearrange(video_or_image, 'b v c h w -> b v c 1 h w')
|
||||
|
||||
assert video_or_image.shape[3] == self.time_seq_len
|
||||
|
||||
# audio shapes - adding view if impliciy to be 1
|
||||
|
||||
if audio_or_spec.ndim == 2 and not self.ast_accept_spec:
|
||||
audio_or_spec = rearrange(audio_or_spec, 'b t -> b 1 t')
|
||||
|
||||
elif audio_or_spec.ndim == 3 and self.ast_accept_spec:
|
||||
audio_or_spec = rearrange(audio_or_spec, 'b f t -> b 1 f t')
|
||||
|
||||
# to images
|
||||
|
||||
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
|
||||
|
||||
images, image_packed_shape = pack([images], '* c h w')
|
||||
|
||||
# to audio
|
||||
|
||||
if self.ast_accept_spec:
|
||||
audio_or_spec, audio_packed_shape = pack([audio_or_spec], '* f t')
|
||||
else:
|
||||
audio_or_spec, audio_packed_shape = pack([audio_or_spec], '* t')
|
||||
|
||||
# get representation trajectory from vit
|
||||
|
||||
vit_forward_context = torch.no_grad if freeze_vit else nullcontext
|
||||
|
||||
with vit_forward_context():
|
||||
embed, hiddens = self.vit(images, return_hiddens = True)
|
||||
|
||||
hiddens = cat((hiddens, embed[None, ...]))
|
||||
|
||||
# extract the hiddens needed for the action cross attention
|
||||
|
||||
hiddens = hiddens[self.vit_layer_indices]
|
||||
|
||||
# unpack temporarily for embedding
|
||||
|
||||
hiddens, = unpack(hiddens, image_packed_shape, 'l * n d') # l for layers
|
||||
|
||||
# maybe add time embeddings
|
||||
|
||||
if self.is_video:
|
||||
time_pos_emb = rearrange(self.time_pos_emb, 't d -> t 1 d')
|
||||
hiddens = hiddens + time_pos_emb
|
||||
|
||||
# maybe view embeddings
|
||||
|
||||
if exists(self.image_view_emb):
|
||||
assert self.image_view_emb.shape[0] == hiddens.shape[2]
|
||||
|
||||
image_view_emb = rearrange(self.image_view_emb, 'v d -> v 1 1 d')
|
||||
hiddens = hiddens + image_view_emb
|
||||
|
||||
# get representation trajectory from ast
|
||||
|
||||
ast_forward_context = torch.no_grad if freeze_ast else nullcontext
|
||||
|
||||
with ast_forward_context():
|
||||
audio_embed, audio_hiddens = self.ast(audio_or_spec, return_hiddens = True)
|
||||
|
||||
audio_hiddens = cat((audio_hiddens, audio_embed[None, ...]))
|
||||
|
||||
# extract the hiddens needed for the action cross attention
|
||||
|
||||
audio_hiddens = audio_hiddens[self.ast_layer_indices]
|
||||
|
||||
# unpack audio temporarily for embedding
|
||||
|
||||
audio_hiddens, = unpack(audio_hiddens, audio_packed_shape, 'l * n d') # l for layers
|
||||
|
||||
# maybe audio view embeddings
|
||||
|
||||
if exists(self.audio_view_emb):
|
||||
assert self.audio_view_emb.shape[0] == audio_hiddens.shape[2]
|
||||
|
||||
audio_view_emb = rearrange(self.audio_view_emb, 'v d -> v 1 1 d')
|
||||
audio_hiddens = audio_hiddens + audio_view_emb
|
||||
|
||||
# maybe tasks
|
||||
|
||||
if exists(tasks):
|
||||
assert self.has_tasks, f'`num_tasks` must be set on `VAT` for task conditioning'
|
||||
|
||||
task_emb = self.task_emb[tasks]
|
||||
|
||||
# cross from actions to representation trajectory
|
||||
|
||||
image_context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
|
||||
|
||||
audio_context = rearrange(audio_hiddens, 'l b v n d -> l b (v n) d')
|
||||
|
||||
# get main action tokens and maybe append extra
|
||||
|
||||
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
|
||||
|
||||
has_extra = exists(extra)
|
||||
|
||||
if has_extra:
|
||||
assert self.accept_extra_token
|
||||
|
||||
extra_token = self.to_extra_token(extra)
|
||||
|
||||
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
|
||||
|
||||
# register tokens
|
||||
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
|
||||
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
|
||||
|
||||
# cross attention
|
||||
|
||||
hiddens = [action_tokens]
|
||||
|
||||
for (maybe_film, maybe_self_attn, image_cross_attn, audio_cross_attn, ff), image_layer_context, audio_layer_context in zip(self.layers, image_context, audio_context):
|
||||
|
||||
if exists(tasks):
|
||||
action_tokens = maybe_film(action_tokens, task_emb)
|
||||
|
||||
action_tokens = image_cross_attn(action_tokens, image_layer_context) + action_tokens
|
||||
|
||||
action_tokens = audio_cross_attn(action_tokens, audio_layer_context) + action_tokens
|
||||
|
||||
if exists(maybe_self_attn):
|
||||
action_tokens = maybe_self_attn(action_tokens) + action_tokens
|
||||
|
||||
action_tokens = ff(action_tokens) + action_tokens
|
||||
|
||||
hiddens.append(action_tokens)
|
||||
|
||||
# unpack registers
|
||||
|
||||
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
|
||||
|
||||
# maybe unpack extra
|
||||
|
||||
if has_extra:
|
||||
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
|
||||
|
||||
# norm and prediction
|
||||
|
||||
action_tokens = self.final_norm(action_tokens)
|
||||
|
||||
pred_action = self.to_pred_action(action_tokens)
|
||||
|
||||
if not return_loss:
|
||||
if not return_hiddens:
|
||||
return pred_action
|
||||
|
||||
return pred_action, stack(hiddens)
|
||||
|
||||
assert pred_action.shape[1] == actions.shape[1]
|
||||
|
||||
# they found l1 loss suffices
|
||||
|
||||
return F.l1_loss(pred_action, actions)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
vit = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 384,
|
||||
heads = 8,
|
||||
depth = 4,
|
||||
mlp_dim = 384 * 4
|
||||
)
|
||||
|
||||
ast = AST(
|
||||
dim = 384,
|
||||
depth = 4,
|
||||
heads = 8,
|
||||
num_classes = 1000,
|
||||
patch_size = 16,
|
||||
mlp_dim = 384 * 4
|
||||
)
|
||||
|
||||
vaat = VAAT(
|
||||
vit,
|
||||
ast,
|
||||
dim = 512,
|
||||
depth = 9,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
mlp_dim = 2048,
|
||||
dim_action = 20,
|
||||
action_chunk_len = 7,
|
||||
time_seq_len = 4,
|
||||
num_image_views = 2,
|
||||
num_audio_views = 2,
|
||||
num_tasks = 4,
|
||||
add_self_attn = True,
|
||||
dim_extra_token = 33, # extra token with some variable dimension
|
||||
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
|
||||
0, 0, 1, 1, 2, 2, 3, 3, 4
|
||||
),
|
||||
ast_layer_indices = (
|
||||
1, 1, 1, 2, 2, 2, 3, 3, 3
|
||||
)
|
||||
)
|
||||
|
||||
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
|
||||
audio = torch.randn(2, 2, 14_100 * 5)
|
||||
|
||||
tasks = torch.randint(0, 4, (2,))
|
||||
extra = torch.randn(2, 33) # extra internal state
|
||||
|
||||
actions = torch.randn(2, 7, 20) # actions for learning
|
||||
|
||||
loss = vaat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss.backward()
|
||||
|
||||
# after much training
|
||||
|
||||
pred_actions, hiddens = vaat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
|
||||
|
||||
assert pred_actions.shape == (2, 7, 20)
|
||||
528
vit_pytorch/vat.py
Normal file
528
vit_pytorch/vat.py
Normal file
@@ -0,0 +1,528 @@
|
||||
from __future__ import annotations
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, cat, stack, tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class FiLM(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
):
|
||||
super().__init__()
|
||||
proj = nn.Linear(dim, dim * 2)
|
||||
|
||||
self.to_gamma_beta = nn.Sequential(
|
||||
proj,
|
||||
Rearrange('b (two d) -> two b 1 d', two = 2)
|
||||
)
|
||||
|
||||
nn.init.zeros_(proj.weight)
|
||||
nn.init.zeros_(proj.bias)
|
||||
|
||||
def forward(self, tokens, cond):
|
||||
gamma, beta = self.to_gamma_beta(cond)
|
||||
|
||||
return tokens * gamma + beta
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
hidden_dim,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_context = None,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
cross_attend = False
|
||||
):
|
||||
super().__init__()
|
||||
dim_context = default(dim_context, dim)
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.cross_attend = cross_attend
|
||||
self.context_norm = nn.LayerNorm(dim_context) if cross_attend else None
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x, context = None):
|
||||
|
||||
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross attending, or vice versa'
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
# handle norming of context for cross attention
|
||||
|
||||
kv_input = x
|
||||
|
||||
if self.cross_attend:
|
||||
context = self.context_norm(context)
|
||||
kv_input = context
|
||||
|
||||
# project for queries, keys, values
|
||||
|
||||
qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
dim_head,
|
||||
mlp_dim,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_hiddens = False
|
||||
):
|
||||
|
||||
hiddens = []
|
||||
|
||||
for attn, ff in self.layers:
|
||||
hiddens.append(x)
|
||||
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
if not return_hiddens:
|
||||
return x
|
||||
|
||||
return x, hiddens
|
||||
|
||||
class ViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
pool = 'cls',
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
num_register_tokens = 0
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
|
||||
|
||||
def forward(self, img, return_hiddens = False):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
x += self.pos_embedding[:n]
|
||||
|
||||
cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = b)
|
||||
|
||||
x, packed_shape = pack((register_tokens, cls_tokens, x), 'b * d')
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
x, hiddens = self.transformer(x, return_hiddens = True)
|
||||
|
||||
# return the representation trajectory
|
||||
|
||||
if return_hiddens:
|
||||
return x, stack(hiddens)
|
||||
|
||||
register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
# proposed VAT
|
||||
|
||||
# https://openreview.net/forum?id=TalHOvvLZu
|
||||
# simple way to get SOTA on Libero dataset (beating fine-tuned pi-zero)
|
||||
|
||||
class VAT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
vit: ViT | dict,
|
||||
*,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
dim_head,
|
||||
dim_action,
|
||||
mlp_dim,
|
||||
num_views = None,
|
||||
num_tasks = None,
|
||||
dim_extra_token = None,
|
||||
num_register_tokens = 4,
|
||||
action_chunk_len = 7,
|
||||
time_seq_len = 1,
|
||||
dropout = 0.,
|
||||
add_self_attn = True, # in the paper, they didn't have any ways for the action token to exchange information with the extra token, so we'll just add it as an option
|
||||
self_attn_heads = 4,
|
||||
self_attn_dim_head = 32,
|
||||
vit_layer_indices: tuple[int, ...] | None = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(vit, dict):
|
||||
vit = ViT(**vit)
|
||||
|
||||
self.vit = vit
|
||||
|
||||
vit_dim = vit.dim
|
||||
|
||||
assert vit.depth == depth or exists(vit_layer_indices), f'if the VAT depth is not equal to the ViT depth, you must pass in the indices from the ViT to be layered to the VAT in order from bottom to top'
|
||||
|
||||
vit_layer_indices = default(vit_layer_indices, tuple(range(depth)))
|
||||
|
||||
assert len(vit_layer_indices) == depth, f'number of vit layer indices {len(vit_layer_indices)} does not much the VAT depth {depth}'
|
||||
|
||||
self.register_buffer('layer_indices', tensor(vit_layer_indices), persistent = False)
|
||||
|
||||
# handle maybe multiple frames
|
||||
|
||||
is_video = time_seq_len > 1
|
||||
|
||||
self.is_video = is_video
|
||||
self.time_seq_len = time_seq_len
|
||||
self.time_pos_emb = nn.Parameter(torch.randn(time_seq_len, vit_dim) * 1e-2) if is_video else None
|
||||
|
||||
# maybe view embeddings
|
||||
|
||||
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
|
||||
|
||||
# handle maybe task conditioning
|
||||
|
||||
self.has_tasks = exists(num_tasks)
|
||||
|
||||
if self.has_tasks:
|
||||
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
|
||||
|
||||
# register tokens from Darcet et al.
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
|
||||
|
||||
# to action tokens
|
||||
|
||||
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
|
||||
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
maybe_film = FiLM(dim = dim) if self.has_tasks else None
|
||||
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
|
||||
|
||||
self.layers.append(ModuleList([
|
||||
maybe_film,
|
||||
maybe_self_attn,
|
||||
Attention(dim = dim, dim_context = vit_dim, heads = heads, dim_head = dim_head, dropout = dropout, cross_attend = True),
|
||||
FeedForward(dim = dim, hidden_dim = mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
self.final_norm = nn.LayerNorm(dim)
|
||||
self.to_pred_action = nn.Linear(dim, dim_action, bias = False)
|
||||
|
||||
# handle the extra token
|
||||
|
||||
self.accept_extra_token = exists(dim_extra_token)
|
||||
|
||||
if exists(dim_extra_token):
|
||||
self.to_extra_token = nn.Linear(dim_extra_token, dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
|
||||
*,
|
||||
extra = None, # (b d) - batch, dim extra
|
||||
tasks = None, # (b)
|
||||
actions = None, # (b k d) - batch, action chunk length, action dimension
|
||||
return_hiddens = False,
|
||||
freeze_vit = False
|
||||
):
|
||||
batch = video_or_image.shape[0]
|
||||
return_loss = exists(actions)
|
||||
|
||||
# handle some various input dimensions
|
||||
|
||||
if video_or_image.ndim == 4:
|
||||
video_or_image = rearrange(video_or_image, 'b 1 c h w')
|
||||
|
||||
assert (
|
||||
(video_or_image.ndim == 5 and not self.is_video) or
|
||||
(video_or_image.ndim == 6 and self.is_video)
|
||||
)
|
||||
|
||||
if video_or_image.ndim == 5:
|
||||
video_or_image = rearrange(video_or_image, 'b v c h w -> b v c 1 h w')
|
||||
|
||||
assert video_or_image.shape[3] == self.time_seq_len
|
||||
|
||||
# to images
|
||||
|
||||
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
|
||||
|
||||
images, packed_shape = pack([images], '* c h w')
|
||||
|
||||
# get representation trajectory from vit
|
||||
|
||||
vit_forward_context = torch.no_grad if freeze_vit else nullcontext
|
||||
|
||||
with vit_forward_context():
|
||||
embed, hiddens = self.vit(images, return_hiddens = True)
|
||||
|
||||
hiddens = cat((hiddens, embed[None, ...]))
|
||||
|
||||
# extract the hiddens needed for the action cross attention
|
||||
|
||||
hiddens = hiddens[self.layer_indices]
|
||||
|
||||
# pack temporarily for embedding
|
||||
|
||||
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers
|
||||
|
||||
# maybe add time embeddings
|
||||
|
||||
if self.is_video:
|
||||
time_pos_emb = rearrange(self.time_pos_emb, 't d -> t 1 d')
|
||||
hiddens = hiddens + time_pos_emb
|
||||
|
||||
# maybe view embeddings
|
||||
|
||||
if exists(self.view_emb):
|
||||
assert self.view_emb.shape[0] == hiddens.shape[2]
|
||||
|
||||
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
|
||||
hiddens = hiddens + view_emb
|
||||
|
||||
# maybe tasks
|
||||
|
||||
if exists(tasks):
|
||||
assert self.has_tasks, f'`num_tasks` must be set on `VAT` for task conditioning'
|
||||
|
||||
task_emb = self.task_emb[tasks]
|
||||
|
||||
# cross from actions to representation trajectory
|
||||
|
||||
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
|
||||
|
||||
# get main action tokens and maybe append extra
|
||||
|
||||
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
|
||||
|
||||
has_extra = exists(extra)
|
||||
|
||||
if has_extra:
|
||||
assert self.accept_extra_token
|
||||
|
||||
extra_token = self.to_extra_token(extra)
|
||||
|
||||
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
|
||||
|
||||
# register tokens
|
||||
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
|
||||
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
|
||||
|
||||
# cross attention
|
||||
|
||||
hiddens = [action_tokens]
|
||||
|
||||
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
|
||||
|
||||
if exists(tasks):
|
||||
action_tokens = maybe_film(action_tokens, task_emb)
|
||||
|
||||
action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
|
||||
|
||||
if exists(maybe_self_attn):
|
||||
action_tokens = maybe_self_attn(action_tokens) + action_tokens
|
||||
|
||||
action_tokens = ff(action_tokens) + action_tokens
|
||||
|
||||
hiddens.append(action_tokens)
|
||||
|
||||
# unpack registers
|
||||
|
||||
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
|
||||
|
||||
# maybe unpack extra
|
||||
|
||||
if has_extra:
|
||||
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
|
||||
|
||||
# norm and prediction
|
||||
|
||||
action_tokens = self.final_norm(action_tokens)
|
||||
|
||||
pred_action = self.to_pred_action(action_tokens)
|
||||
|
||||
if not return_loss:
|
||||
if not return_hiddens:
|
||||
return pred_action
|
||||
|
||||
return pred_action, stack(hiddens)
|
||||
|
||||
assert pred_action.shape[1] == actions.shape[1]
|
||||
|
||||
# they found l1 loss suffices
|
||||
|
||||
return F.l1_loss(pred_action, actions)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
vit = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 256,
|
||||
heads = 8,
|
||||
depth = 4,
|
||||
mlp_dim = 1024
|
||||
)
|
||||
|
||||
vat = VAT(
|
||||
vit,
|
||||
dim = 512,
|
||||
depth = 9,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
mlp_dim = 2048,
|
||||
dim_action = 20,
|
||||
action_chunk_len = 7,
|
||||
time_seq_len = 4,
|
||||
num_views = 2,
|
||||
num_tasks = 4,
|
||||
add_self_attn = True,
|
||||
dim_extra_token = 33, # extra token with some variable dimension
|
||||
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
|
||||
0, 0, 1, 1, 2, 2, 3, 3, 4
|
||||
)
|
||||
)
|
||||
|
||||
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
|
||||
tasks = torch.randint(0, 4, (2,))
|
||||
extra = torch.randn(2, 33) # extra internal state
|
||||
|
||||
actions = torch.randn(2, 7, 20) # actions for learning
|
||||
|
||||
loss = vat(images, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss.backward()
|
||||
|
||||
# after much training
|
||||
|
||||
pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True)
|
||||
|
||||
assert pred_actions.shape == (2, 7, 20)
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
@@ -11,7 +12,7 @@ def pair(t):
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
@@ -26,7 +27,7 @@ class FeedForward(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
@@ -62,13 +63,14 @@ class Attention(nn.Module):
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
@@ -80,7 +82,7 @@ class Transformer(nn.Module):
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViT(nn.Module):
|
||||
class ViT(Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
@@ -90,7 +92,9 @@ class ViT(nn.Module):
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
num_cls_tokens = 1 if pool == 'cls' else 0
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
@@ -99,8 +103,9 @@ class ViT(nn.Module):
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
|
||||
self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
@@ -111,12 +116,15 @@ class ViT(nn.Module):
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
batch = img.shape[0]
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
|
||||
seq = x.shape[1]
|
||||
|
||||
x = x + self.pos_embedding[:seq]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
@@ -89,7 +89,7 @@ class ViT(nn.Module):
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
191
vit_pytorch/vit_nd.py
Normal file
191
vit_pytorch/vit_nd.py
Normal file
@@ -0,0 +1,191 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def join(arr, delimiter = ' '):
|
||||
return delimiter.join(arr)
|
||||
|
||||
def ensure_tuple(t, length):
|
||||
if isinstance(t, (tuple, list)):
|
||||
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
|
||||
return tuple(t)
|
||||
return (t,) * length
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class ViTND(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ndim: int,
|
||||
input_shape: int | tuple[int, ...],
|
||||
patch_size: int | tuple[int, ...],
|
||||
num_classes: int,
|
||||
dim: int,
|
||||
depth: int,
|
||||
heads: int,
|
||||
mlp_dim: int,
|
||||
pool: str = 'cls',
|
||||
channels: int = 3,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.,
|
||||
emb_dropout: float = 0.
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.ndim = ndim
|
||||
self.pool = pool
|
||||
|
||||
input_shape = ensure_tuple(input_shape, ndim)
|
||||
patch_size = ensure_tuple(patch_size, ndim)
|
||||
|
||||
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
|
||||
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
|
||||
|
||||
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
|
||||
num_patches = 1
|
||||
for n in num_patches_per_dim:
|
||||
num_patches *= n
|
||||
|
||||
patch_dim = channels
|
||||
for p in patch_size:
|
||||
patch_dim *= p
|
||||
|
||||
dim_names = 'fghijkl'[:ndim]
|
||||
|
||||
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
|
||||
patch_dims = [f'p{i}' for i in range(ndim)]
|
||||
|
||||
input_pattern = f'b c {join(input_dims)}'
|
||||
output_pattern = f'b ({join(dim_names)}) ({join(patch_dims)} c)'
|
||||
rearrange_str = f'{input_pattern} -> {output_pattern}'
|
||||
|
||||
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange(rearrange_str, **rearrange_kwargs),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.to_patch_embedding(x)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x[:, 1:].mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model = ViTND(
|
||||
ndim = 4,
|
||||
input_shape = (8, 16, 32, 64),
|
||||
patch_size = (2, 4, 4, 8),
|
||||
num_classes = 1000,
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
channels = 3,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
occupancy_time = torch.randn(2, 3, 8, 16, 32, 64)
|
||||
|
||||
logits = model(occupancy_time)
|
||||
325
vit_pytorch/vit_nd_rotary.py
Normal file
325
vit_pytorch/vit_nd_rotary.py
Normal file
@@ -0,0 +1,325 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn, arange, cat, stack, Tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat, reduce, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1, p = 2)
|
||||
|
||||
def join(arr, delimiter = ' '):
|
||||
return delimiter.join(arr)
|
||||
|
||||
def ensure_tuple(t, length):
|
||||
if isinstance(t, (tuple, list)):
|
||||
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
|
||||
return tuple(t)
|
||||
return (t,) * length
|
||||
|
||||
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
||||
# https://jerryxio.ng/posts/nd-rope/
|
||||
|
||||
def _phi(m: int) -> float:
|
||||
x = 2.0
|
||||
for _ in range(10):
|
||||
x = (1 + x) ** (1.0 / (m + 1.0))
|
||||
return x
|
||||
|
||||
def make_directions(n: int, d: int) -> Tensor:
|
||||
g = _phi(d)
|
||||
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
|
||||
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
|
||||
z = torch.fmod(i * alpha, 1.0)
|
||||
directions = torch.erfinv(2.0 * z - 1.0)
|
||||
directions = l2norm(directions)
|
||||
return directions.float()
|
||||
|
||||
class GoldenGateRoPENd(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_pos: int,
|
||||
heads: int,
|
||||
dim_head: int,
|
||||
rope_min_freq: float = 1.0,
|
||||
rope_max_freq: float = 10000.0,
|
||||
rope_p_zero_freqs: float = 0.0, # proportion of frequencies set to 0
|
||||
):
|
||||
super().__init__()
|
||||
n_freqs = dim_head // 2
|
||||
n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
|
||||
|
||||
omega = cat((
|
||||
torch.zeros(n_zero_freqs),
|
||||
rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
|
||||
))
|
||||
|
||||
directions = rearrange(
|
||||
make_directions(heads * n_freqs, dim_pos),
|
||||
'(h f) p -> h f p',
|
||||
h = heads
|
||||
)
|
||||
|
||||
omega_expanded = rearrange(omega, 'f -> f 1')
|
||||
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
|
||||
|
||||
def forward(self, input: Tensor, pos: Tensor) -> Tensor:
|
||||
# input shape: (b, h, n, d) where d = head_dim
|
||||
# pos shape: (b, n, p) where p = pos_dim
|
||||
# self.freqs shape: (h, f, p) where f = d // 2
|
||||
|
||||
x, y = input.float().chunk(2, dim = -1) # both (b, h, n, f)
|
||||
|
||||
# Expand dimensions for broadcasting
|
||||
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
|
||||
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
|
||||
|
||||
# Compute theta for each (batch, head, seq, freq)
|
||||
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
|
||||
|
||||
cos_theta = torch.cos(theta)
|
||||
sin_theta = torch.sin(theta)
|
||||
|
||||
# Apply rotation
|
||||
x_out = x * cos_theta - y * sin_theta
|
||||
y_out = x * sin_theta + y * cos_theta
|
||||
|
||||
output = cat((x_out, y_out), dim=-1)
|
||||
return output.type_as(input)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., rotary_emb = None):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.rotary_emb = rotary_emb
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x, pos = None):
|
||||
x = self.norm(x)
|
||||
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
# Apply rotary embeddings if available
|
||||
if exists(self.rotary_emb):
|
||||
assert exists(pos)
|
||||
q = self.rotary_emb(q, pos)
|
||||
k = self.rotary_emb(k, pos)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., rotary_emb = None):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, rotary_emb = rotary_emb),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x, pos = None):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, pos) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class ViTND(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ndim: int,
|
||||
input_shape: int | tuple[int, ...],
|
||||
patch_size: int | tuple[int, ...],
|
||||
num_classes: int,
|
||||
dim: int,
|
||||
depth: int,
|
||||
heads: int,
|
||||
mlp_dim: int,
|
||||
channels: int = 3,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.,
|
||||
emb_dropout: float = 0.,
|
||||
rope_min_freq: float = 1.0,
|
||||
rope_max_freq: float = 10000.0,
|
||||
rope_p_zero_freqs: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
|
||||
|
||||
self.ndim = ndim
|
||||
|
||||
input_shape = ensure_tuple(input_shape, ndim)
|
||||
patch_size = ensure_tuple(patch_size, ndim)
|
||||
|
||||
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
|
||||
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
|
||||
|
||||
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
|
||||
num_patches = 1
|
||||
for n in num_patches_per_dim:
|
||||
num_patches *= n
|
||||
|
||||
patch_dim = channels
|
||||
for p in patch_size:
|
||||
patch_dim *= p
|
||||
|
||||
dim_names = 'fghijkl'[:ndim]
|
||||
|
||||
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
|
||||
patch_dims = [f'p{i}' for i in range(ndim)]
|
||||
|
||||
input_pattern = f'b c {join(input_dims)}'
|
||||
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
|
||||
rearrange_str = f'{input_pattern} -> {output_pattern}'
|
||||
|
||||
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange(rearrange_str, **rearrange_kwargs),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
# Create rotary embeddings
|
||||
self.rotary_emb = GoldenGateRoPENd(
|
||||
dim_pos = ndim,
|
||||
heads = heads,
|
||||
dim_head = dim_head,
|
||||
rope_min_freq = rope_min_freq,
|
||||
rope_max_freq = rope_max_freq,
|
||||
rope_p_zero_freqs = rope_p_zero_freqs
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, rotary_emb = self.rotary_emb)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def muon_parameters(self):
|
||||
params = []
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, Attention):
|
||||
params.extend([
|
||||
m.to_v.weight,
|
||||
m.to_out[0].weight
|
||||
])
|
||||
elif isinstance(m, FeedForward):
|
||||
params.extend([
|
||||
m.net[1].weight,
|
||||
m.net[-2].weight
|
||||
])
|
||||
|
||||
return params
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_embed = False
|
||||
):
|
||||
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
|
||||
|
||||
batch, *spatial_dims, _, device = *x.shape, x.device
|
||||
|
||||
# Generate position coordinates
|
||||
|
||||
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
|
||||
grid = torch.meshgrid(*grids, indexing = 'ij')
|
||||
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
|
||||
|
||||
# flatten spatial dimensions for attention with nd rotary
|
||||
|
||||
pos = repeat(pos, '... p -> b (...) p', b = batch)
|
||||
x, packed_shape = pack([x], 'b * d')
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
embed = self.transformer(x, pos)
|
||||
|
||||
# return the embed with reconstituted patch shape
|
||||
|
||||
if return_embed:
|
||||
embed, = unpack(embed, packed_shape, 'b * d')
|
||||
return embed
|
||||
|
||||
# pooling to logits
|
||||
|
||||
pooled = reduce(embed, 'b n d -> b d', 'mean')
|
||||
|
||||
pooled = self.to_latent(pooled)
|
||||
return self.mlp_head(pooled)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model = ViTND(
|
||||
ndim = 5,
|
||||
input_shape = (4, 8, 16, 32, 64),
|
||||
patch_size = (2, 2, 4, 4, 8),
|
||||
num_classes = 1000,
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
channels = 3,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
data = torch.randn(2, 3, 4, 8, 16, 32, 64)
|
||||
|
||||
logits = model(data)
|
||||
|
||||
embed = model(data, return_embed = True) # (2, 2, 4, 4, 8, 8, 512)
|
||||
234
vit_pytorch/vit_with_decorr.py
Normal file
234
vit_pytorch/vit_with_decorr.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# https://arxiv.org/abs/2510.14657
|
||||
# but instead of their decorr module updated with SGD, remove all projections and just return a decorrelation auxiliary loss
|
||||
|
||||
import torch
|
||||
from torch import nn, stack, tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat, reduce, einsum, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# decorr loss
|
||||
|
||||
class DecorrelationLoss(Module):
|
||||
def __init__(
|
||||
self,
|
||||
sample_frac = 1.,
|
||||
soft_validate_num_sampled = False
|
||||
):
|
||||
super().__init__()
|
||||
assert 0. <= sample_frac <= 1.
|
||||
self.need_sample = sample_frac < 1.
|
||||
self.sample_frac = sample_frac
|
||||
|
||||
self.soft_validate_num_sampled = soft_validate_num_sampled
|
||||
self.register_buffer('zero', tensor(0.), persistent = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tokens
|
||||
):
|
||||
batch, seq_len, dim, device = *tokens.shape[-3:], tokens.device
|
||||
|
||||
if self.need_sample:
|
||||
num_sampled = int(seq_len * self.sample_frac)
|
||||
|
||||
assert self.soft_validate_num_sampled or num_sampled >= 2.
|
||||
|
||||
if num_sampled <= 1:
|
||||
return self.zero
|
||||
|
||||
tokens, packed_shape = pack([tokens], '* n d e')
|
||||
|
||||
indices = torch.randn(tokens.shape[:2]).argsort(dim = -1)[..., :num_sampled, :]
|
||||
|
||||
batch_arange = torch.arange(tokens.shape[0], device = tokens.device)
|
||||
batch_arange = rearrange(batch_arange, 'b -> b 1')
|
||||
|
||||
tokens = tokens[batch_arange, indices]
|
||||
tokens, = unpack(tokens, packed_shape, '* n d e')
|
||||
|
||||
dist = einsum(tokens, tokens, '... n d, ... n e -> ... d e') / tokens.shape[-2]
|
||||
eye = torch.eye(dim, device = device)
|
||||
|
||||
loss = dist.pow(2) * (1. - eye) / ((dim - 1) * dim)
|
||||
|
||||
loss = reduce(loss, '... b d e -> b', 'sum')
|
||||
return loss.mean()
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
normed = self.norm(x)
|
||||
return self.net(x), normed
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
normed = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(normed).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
return self.to_out(out), normed
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
normed_inputs = []
|
||||
|
||||
for attn, ff in self.layers:
|
||||
attn_out, attn_normed_inp = attn(x)
|
||||
x = attn_out + x
|
||||
|
||||
ff_out, ff_normed_inp = ff(x)
|
||||
x = ff_out + x
|
||||
|
||||
normed_inputs.append(attn_normed_inp)
|
||||
normed_inputs.append(ff_normed_inp)
|
||||
|
||||
return self.norm(x), stack(normed_inputs)
|
||||
|
||||
class ViT(Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., decorr_sample_frac = 1.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
# decorrelation loss related
|
||||
|
||||
self.has_decorr_loss = decorr_sample_frac > 0.
|
||||
|
||||
if self.has_decorr_loss:
|
||||
self.decorr_loss = DecorrelationLoss(decorr_sample_frac)
|
||||
|
||||
self.register_buffer('zero', torch.tensor(0.), persistent = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img,
|
||||
return_decorr_aux_loss = None
|
||||
):
|
||||
return_decorr_aux_loss = default(return_decorr_aux_loss, self.training) and self.has_decorr_loss
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x, normed_layer_inputs = self.transformer(x)
|
||||
|
||||
# maybe return decor loss
|
||||
|
||||
decorr_aux_loss = self.zero
|
||||
|
||||
if return_decorr_aux_loss:
|
||||
decorr_aux_loss = self.decorr_loss(normed_layer_inputs)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x), decorr_aux_loss
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
decorr_loss = DecorrelationLoss(0.1)
|
||||
|
||||
hiddens = torch.randn(6, 2, 512, 256)
|
||||
|
||||
decorr_loss(hiddens)
|
||||
decorr_loss(hiddens[0])
|
||||
|
||||
decorr_loss = DecorrelationLoss(0.0001, soft_validate_num_sampled = True)
|
||||
out = decorr_loss(hiddens)
|
||||
assert out.item() == 0
|
||||
@@ -78,6 +78,30 @@ class Transformer(nn.Module):
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class FactorizedTransformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
b, f, n, _ = x.shape
|
||||
for spatial_attn, temporal_attn, ff in self.layers:
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
x = spatial_attn(x) + x
|
||||
x = rearrange(x, '(b f) n d -> (b n) f d', b=b, f=f)
|
||||
x = temporal_attn(x) + x
|
||||
x = ff(x) + x
|
||||
x = rearrange(x, '(b n) f d -> b f n d', b=b, n=n)
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -96,7 +120,8 @@ class ViT(nn.Module):
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.
|
||||
emb_dropout = 0.,
|
||||
variant = 'factorized_encoder',
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
@@ -104,6 +129,7 @@ class ViT(nn.Module):
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
|
||||
assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'
|
||||
|
||||
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
num_frame_patches = (frames // frame_patch_size)
|
||||
@@ -115,7 +141,7 @@ class ViT(nn.Module):
|
||||
self.global_average_pool = pool == 'mean'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
@@ -125,15 +151,20 @@ class ViT(nn.Module):
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
|
||||
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
|
||||
if variant == 'factorized_encoder':
|
||||
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
|
||||
elif variant == 'factorized_self_attention':
|
||||
assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
|
||||
self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
self.variant = variant
|
||||
|
||||
def forward(self, video):
|
||||
x = self.to_patch_embedding(video)
|
||||
@@ -147,32 +178,37 @@ class ViT(nn.Module):
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
if self.variant == 'factorized_encoder':
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
|
||||
# attend across space
|
||||
# attend across space
|
||||
|
||||
x = self.spatial_transformer(x)
|
||||
x = self.spatial_transformer(x)
|
||||
x = rearrange(x, '(b f) n d -> b f n d', b = b)
|
||||
|
||||
x = rearrange(x, '(b f) n d -> b f n d', b = b)
|
||||
# excise out the spatial cls tokens or average pool for temporal attention
|
||||
|
||||
# excise out the spatial cls tokens or average pool for temporal attention
|
||||
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
|
||||
|
||||
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
|
||||
# append temporal CLS tokens
|
||||
|
||||
# append temporal CLS tokens
|
||||
if exists(self.temporal_cls_token):
|
||||
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
|
||||
|
||||
if exists(self.temporal_cls_token):
|
||||
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
|
||||
x = torch.cat((temporal_cls_tokens, x), dim = 1)
|
||||
|
||||
|
||||
x = torch.cat((temporal_cls_tokens, x), dim = 1)
|
||||
# attend across time
|
||||
|
||||
# attend across time
|
||||
x = self.temporal_transformer(x)
|
||||
|
||||
x = self.temporal_transformer(x)
|
||||
# excise out temporal cls token or average pool
|
||||
|
||||
# excise out temporal cls token or average pool
|
||||
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
|
||||
|
||||
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
|
||||
elif self.variant == 'factorized_self_attention':
|
||||
x = self.factorized_transformer(x)
|
||||
x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
Reference in New Issue
Block a user