Compare commits

...

20 Commits
1.16.2 ... main

Author SHA1 Message Date
lucidrains
7e18d0302e stop relying on github 2026-03-27 08:07:42 -07:00
lucidrains
b80676e09c add an attention residual example (kimi team) as well as dino / byol redone with sigreg (lejepa) 2026-03-27 07:51:12 -07:00
lucidrains
fc1e727428 add ability to condition on binned advantages for the vision action transformers 2026-02-19 14:10:44 -08:00
lucidrains
6032a54b48 patch 2026-02-11 11:49:51 -08:00
Harikrishna KP
06a1f42924 Fix ViViT Transformer not passing use_flash_attn to Attention and duplicate mask reshape (#360)
Two related bugs in vivit.py:

1. Transformer.__init__ accepted use_flash_attn but never forwarded it to the
   Attention modules it creates. Since Attention defaults to use_flash_attn=True,
   setting use_flash_attn=False on ViViT had no effect on the factorized_encoder
   variant's spatial and temporal transformers.

2. Attention.forward reshaped the mask from 2D to 4D before the flash/non-flash
   branch (line 82), then attempted to reshape it again inside the non-flash
   branch (line 92). When the non-flash code path is actually reached with a
   mask, einops raises an error because the mask is already 4D.

   These bugs masked each other: bug #1 prevented bug #2 from triggering because
   the non-flash path was never taken even when requested.

Fix: pass use_flash_attn through to Attention in Transformer.__init__, and
remove the redundant second mask rearrange in the non-flash branch.
2026-02-11 11:49:31 -08:00
Phil Wang
6ae6a3ab64 cleanup 2026-02-04 13:29:40 -08:00
lucidrains
827300beed add vit with keel post ln, proposed by bytedance for scaling depth 2026-02-04 09:09:17 -08:00
lucidrains
a7c4e7f79f best practices 2026-01-28 05:04:20 -08:00
lucidrains
54ec3f2af5 address https://github.com/lucidrains/vit-pytorch/issues/357 2026-01-17 05:11:36 -08:00
lucidrains
9aa52cce49 do an actual vat with siglip arch, and have gemini flash craft the weight loading script from hf 2026-01-15 11:27:05 -08:00
lucidrains
4c89017444 fix up vivit 2026-01-08 06:36:40 -08:00
Eyal Mazuz
580258d99e Allow to pass mask parameter for temporal transformer in ViVit (#356)
* Mask for temporal transformer in ViVit

This allows to pad videos to certain length which allow the transformer
to ignore padded frames using batch sizes > 1

* Added flash attention to vivit

* Added flash attention to vivit

* Added flash attention to vivit
2026-01-08 06:08:54 -08:00
lucidrains
6f1caef987 allow for no final output head on the vit 2026-01-06 13:00:48 -08:00
lucidrains
fb5014f0ee get a version of n-dimensional vit with golden gate polar coordinate embeddings into the repo for future use 2025-12-25 09:11:13 -08:00
Phil Wang
0b7518ef45 educate 2025-12-21 07:06:20 -08:00
lucidrains
077d8c188f fix distill 2025-12-10 15:52:10 -08:00
lucidrains
5888f05300 1.16.4 2025-12-07 04:32:52 -08:00
Amit Moryossef
d518e89573 cache position grids in NaViT forward pass (#354)
Use lru_cache to cache unique (ph, pw, device) position grids, avoiding
redundant computation when multiple images share the same patch
dimensions. Cache persists across forward passes.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-07 04:32:30 -08:00
lucidrains
dd6462d19b release small navit perf 2025-12-06 04:57:12 -08:00
Amit Moryossef
a1ee1daa1a optimize NaViT with SDPA and vectorized forward pass (#353)
- Replace manual attention with F.scaled_dot_product_attention
- Use repeat_interleave instead of meshgrid for position computation
- Build image_ids efficiently with repeat_interleave instead of F.pad
- Remove unused Rearrange import

~56% speedup (91ms -> 58ms on 512 variable-sized images)
Numerically equivalent (max diff ~5e-4, within flash attention tolerance)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-06 04:56:40 -08:00
44 changed files with 2478 additions and 657 deletions

3
.github/FUNDING.yml vendored
View File

@@ -1,3 +0,0 @@
# These are supported funding model platforms
github: [lucidrains]

View File

@@ -1,36 +0,0 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
name: Upload Python Package
on:
release:
types: [published]
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

View File

@@ -1,34 +0,0 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Test
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8, 3.9]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
python -m pip install -e .
python -m pip install pytest
- name: Test with pytest
run: |
pytest -q

3
.gitignore vendored
View File

@@ -127,3 +127,6 @@ dmypy.json
# Pyre type checker
.pyre/
# scripts
*.sh

134
README.md
View File

@@ -49,7 +49,7 @@
## Vision Transformer - Pytorch
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the [attention](https://www.youtube.com/watch?v=eMlx5fFNoYc) revolution.
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>.
@@ -90,26 +90,26 @@ preds = v(img) # (1, 1000)
## Parameters
- `image_size`: int.
- `image_size`: int.
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
- `patch_size`: int.
Size of patches. `image_size` must be divisible by `patch_size`.
- `patch_size`: int.
Size of patches. `image_size` must be divisible by `patch_size`.
The number of patches is: ` n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
- `num_classes`: int.
- `num_classes`: int.
Number of classes to classify.
- `dim`: int.
- `dim`: int.
Last dimension of output tensor after linear transformation `nn.Linear(..., dim)`.
- `depth`: int.
- `depth`: int.
Number of Transformer blocks.
- `heads`: int.
Number of heads in Multi-head Attention layer.
- `mlp_dim`: int.
Dimension of the MLP (FeedForward) layer.
- `channels`: int, default `3`.
Number of image's channels.
- `dropout`: float between `[0, 1]`, default `0.`.
Dropout rate.
- `emb_dropout`: float between `[0, 1]`, default `0`.
- `heads`: int.
Number of heads in Multi-head Attention layer.
- `mlp_dim`: int.
Dimension of the MLP (FeedForward) layer.
- `channels`: int, default `3`.
Number of image's channels.
- `dropout`: float between `[0, 1]`, default `0.`.
Dropout rate.
- `emb_dropout`: float between `[0, 1]`, default `0`.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling
@@ -972,7 +972,7 @@ torch.save(model.state_dict(), './pretrained-net.pt')
<img src="./images/mp3.png" width="400px"></img>
New <a href="https://arxiv.org/abs/2207.07611">paper</a> that introduces masked position prediction pre-training criteria. This strategy is more efficient than the Masked Autoencoder strategy and has comparable performance.
New <a href="https://arxiv.org/abs/2207.07611">paper</a> that introduces masked position prediction pre-training criteria. This strategy is more efficient than the Masked Autoencoder strategy and has comparable performance.
```python
import torch
@@ -1358,10 +1358,10 @@ learner = Dino(
hidden_layer = 'to_latent', # hidden layer name or index, from which to extract the embedding
projection_hidden_size = 256, # projector network hidden dimension
projection_layers = 4, # number of layers in projection network
num_classes_K = 65336, # output logits dimensions (referenced as K in paper)
num_classes_K = 65536, # output logits dimensions (referenced as K in paper)
student_temp = 0.9, # student temperature
teacher_temp = 0.04, # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper
local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper
global_lower_crop_scale = 0.5, # lower bound for global crop - 0.5 was recommended in the paper
moving_average_decay = 0.9, # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
@@ -1735,7 +1735,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{touvron2020training,
title = {Training data-efficient image transformers & distillation through attention},
title = {Training data-efficient image transformers & distillation through attention},
author = {Hugo Touvron and Matthieu Cord and Matthijs Douze and Francisco Massa and Alexandre Sablayrolles and Hervé Jégou},
year = {2020},
eprint = {2012.12877},
@@ -1768,7 +1768,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{touvron2021going,
title = {Going deeper with Image Transformers},
title = {Going deeper with Image Transformers},
author = {Hugo Touvron and Matthieu Cord and Alexandre Sablayrolles and Gabriel Synnaeve and Hervé Jégou},
year = {2021},
eprint = {2103.17239},
@@ -1801,7 +1801,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{heo2021rethinking,
title = {Rethinking Spatial Dimensions of Vision Transformers},
title = {Rethinking Spatial Dimensions of Vision Transformers},
author = {Byeongho Heo and Sangdoo Yun and Dongyoon Han and Sanghyuk Chun and Junsuk Choe and Seong Joon Oh},
year = {2021},
eprint = {2103.16302},
@@ -1845,7 +1845,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
@@ -1867,7 +1867,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{chen2021regionvit,
title = {RegionViT: Regional-to-Local Attention for Vision Transformers},
title = {RegionViT: Regional-to-Local Attention for Vision Transformers},
author = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan},
year = {2021},
eprint = {2106.02689},
@@ -1878,7 +1878,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{wang2021crossformer,
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
year = {2021},
eprint = {2108.00154},
@@ -1900,7 +1900,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{he2021masked,
title = {Masked Autoencoders Are Scalable Vision Learners},
title = {Masked Autoencoders Are Scalable Vision Learners},
author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
year = {2021},
eprint = {2111.06377},
@@ -1911,7 +1911,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{xie2021simmim,
title = {SimMIM: A Simple Framework for Masked Image Modeling},
title = {SimMIM: A Simple Framework for Masked Image Modeling},
author = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
year = {2021},
eprint = {2111.09886},
@@ -1944,7 +1944,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{lee2021vision,
title = {Vision Transformer for Small-Size Datasets},
title = {Vision Transformer for Small-Size Datasets},
author = {Seung Hoon Lee and Seunghyun Lee and Byung Cheol Song},
year = {2021},
eprint = {2112.13492},
@@ -1966,7 +1966,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{yang2022scalablevit,
title = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer},
title = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer},
author = {Rui Yang and Hailong Ma and Jie Wu and Yansong Tang and Xuefeng Xiao and Min Zheng and Xiu Li},
year = {2022},
eprint = {2203.10790},
@@ -2203,13 +2203,85 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{carrigg2025decorrelationspeedsvisiontransformers,
title = {Decorrelation Speeds Up Vision Transformers},
title = {Decorrelation Speeds Up Vision Transformers},
author = {Kieran Carrigg and Rob van Gastel and Melda Yeghaian and Sander Dalm and Faysal Boughorbel and Marcel van Gerven},
year = {2025},
eprint = {2510.14657},
archivePrefix = {arXiv},
primaryClass = {cs.CV},
url = {https://arxiv.org/abs/2510.14657},
url = {https://arxiv.org/abs/2510.14657},
}
```
```bibtex
@misc{gopalakrishnan2025decouplingwhatwherepolar,
title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
year = {2025},
eprint = {2509.10534},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2509.10534},
}
```
```bibtex
@misc{qiu2025gatedattentionlargelanguage,
title = {Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free},
author = {Zihan Qiu and Zekun Wang and Bo Zheng and Zeyu Huang and Kaiyue Wen and Songlin Yang and Rui Men and Le Yu and Fei Huang and Suozhi Huang and Dayiheng Liu and Jingren Zhou and Junyang Lin},
year = {2025},
eprint = {2505.06708},
archivePrefix = {arXiv},
primaryClass = {cs.CL},
url = {https://arxiv.org/abs/2505.06708}
}
```
```bibtex
@misc{chen2026postlayernormbackstableexpressive,
title = {Post-LayerNorm Is Back: Stable, ExpressivE, and Deep},
author = {Chen Chen and Lai Wei},
year = {2026},
eprint = {2601.19895},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2601.19895},
}
```
```bibtex
@misc{intelligence2025pi06vlalearnsexperience,
title = {$\pi^{*}_{0.6}$: a VLA That Learns From Experience},
author = {Physical Intelligence and Ali Amin and Raichelle Aniceto and Ashwin Balakrishna and Kevin Black and Ken Conley and Grace Connors and James Darpinian and Karan Dhabalia and Jared DiCarlo and Danny Driess and Michael Equi and Adnan Esmail and Yunhao Fang and Chelsea Finn and Catherine Glossop and Thomas Godden and Ivan Goryachev and Lachy Groom and Hunter Hancock and Karol Hausman and Gashon Hussein and Brian Ichter and Szymon Jakubczak and Rowan Jen and Tim Jones and Ben Katz and Liyiming Ke and Chandra Kuchi and Marinda Lamb and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Yao Lu and Vishnu Mano and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Charvi Sharma and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and Will Stoeckle and Alex Swerdlow and James Tanner and Marcel Torne and Quan Vuong and Anna Walling and Haohuan Wang and Blake Williams and Sukwon Yoo and Lili Yu and Ury Zhilinsky and Zhiyuan Zhou},
year = {2025},
eprint = {2511.14759},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2511.14759},
}
```
```bibtex
@misc{kimiteam2026attentionresiduals,
title = {Attention Residuals},
author = {Kimi Team and Guangyu Chen and Yu Zhang and Jianlin Su and Weixin Xu and Siyuan Pan and Yaoyu Wang and Yucheng Wang and Guanduo Chen and Bohong Yin and Yutian Chen and Junjie Yan and Ming Wei and Y. Zhang and Fanqing Meng and Chao Hong and Xiaotong Xie and Shaowei Liu and Enzhe Lu and Yunpeng Tai and Yanru Chen and Xin Men and Haiqing Guo and Y. Charles and Haoyu Lu and Lin Sui and Jinguo Zhu and Zaida Zhou and Weiran He and Weixiao Huang and Xinran Xu and Yuzhi Wang and Guokun Lai and Yulun Du and Yuxin Wu and Zhilin Yang and Xinyu Zhou},
year = {2026},
eprint = {2603.15031},
archivePrefix = {arXiv},
primaryClass = {cs.CL},
url = {https://arxiv.org/abs/2603.15031},
}
```
```bibtex
@misc{balestriero2025lejepa,
title = {LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics},
author = {Randall Balestriero and Yann LeCun},
year = {2025},
eprint = {2511.08544},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2511.08544},
}
```

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.16.2"
version = "1.19.1"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }
@@ -31,8 +31,8 @@ classifiers = [
"Programming Language :: Python :: 3.12",
]
dependencies = [
"einops>=0.7.0",
"torch>=1.10",
"einops>=0.8.2",
"torch>=2.4",
"torchvision",
]
@@ -44,8 +44,8 @@ test = [
]
[project.urls]
Homepage = "https://github.com/lucidrains/vit-pytorch"
Repository = "https://github.com/lucidrains/vit-pytorch"
Homepage = "https://codeberg.org/lucidrains/vit-pytorch"
Repository = "https://codeberg.org/lucidrains/vit-pytorch"
[tool.setuptools]
include-package-data = true

View File

@@ -26,7 +26,7 @@ class AcceptVideoWrapper(Module):
dim_emb = None,
time_seq_len = None,
embed_is_channel_first = False,
output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
proj_embed_to_dim = None
):
super().__init__()

View File

@@ -25,12 +25,12 @@ class DistillMixin:
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
cls_tokens = repeat(self.cls_token, 'n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
x += self.pos_embedding[:(n + 1)]
if distilling:
distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
distill_tokens = repeat(distill_token, 'n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
x = self._attend(x)
@@ -125,7 +125,7 @@ class DistillWrapper(Module):
self.alpha = alpha
self.hard = hard
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
self.distillation_token = nn.Parameter(torch.randn(1, dim))
self.distill_mlp = nn.Sequential(
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),

View File

@@ -103,7 +103,7 @@ class JumboViT(Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
)
jumbo_cls_dim = dim * jumbo_cls_k

View File

@@ -145,7 +145,7 @@ class ViT(nn.Module):
return x
def forward(self, img):
x = self.img_to_tokens(img)
x = self.img_to_tokens(img)
x = self.transformer(x)
@@ -160,7 +160,7 @@ class Adapter(nn.Module):
*,
vit,
num_memories_per_layer = 10,
num_classes = 2,
num_classes = 2,
):
super().__init__()
assert isinstance(vit, ViT)
@@ -188,7 +188,7 @@ class Adapter(nn.Module):
)
# specialized attention mask to preserve the output of the original ViT
# it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa
# it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa
attn_mask = torch.ones((num_patches, num_patches), dtype = torch.bool)
attn_mask = F.pad(attn_mask, (1, num_memories_per_layer), value = False) # main tokens cannot attend to learnable memories per layer
@@ -203,7 +203,7 @@ class Adapter(nn.Module):
# add task specific memory tokens
memory_cls_tokens = repeat(self.memory_cls_token, 'd -> b 1 d', b = b)
tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)
tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)
# pass memories along with image tokens through transformer for attending

319
vit_pytorch/lejepa.py Normal file
View File

@@ -0,0 +1,319 @@
import random
from functools import wraps
import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torchvision import transforms as T
from einops import rearrange
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def singleton(cache_key):
def inner_fn(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
instance = getattr(self, cache_key)
if instance is not None:
return instance
instance = fn(self, *args, **kwargs)
setattr(self, cache_key, instance)
return instance
return wrapper
return inner_fn
def get_module_device(module):
return next(module.parameters()).device
def l2norm(t, eps = 1e-6):
return F.normalize(t, dim = -1, eps = eps)
# loss function
def sigreg_loss(
x,
num_slices = 1024,
domain = (-5, 5),
num_knots = 17
):
# Randall Balestriero - https://arxiv.org/abs/2511.08544
dim, device = x.shape[-1], x.device
# slice sampling
rand_projs = torch.randn((num_slices, dim), device = device)
rand_projs = l2norm(rand_projs)
# integration points
t = torch.linspace(*domain, num_knots, device = device)
# theoretical CF for N(0, 1) and Gauss. window
exp_f = (-0.5 * t.square()).exp()
# empirical CF
x_t = torch.einsum('... d, m d -> ... m', x, rand_projs)
x_t = rearrange(x_t, '... m -> (...) m')
x_t = rearrange(x_t, 'n m -> n m 1') * t
ecf = (1j * x_t).exp().mean(dim = 0)
# weighted L2 distance
err = ecf.sub(exp_f).abs().square().mul(exp_f)
return torch.trapezoid(err, t, dim = -1).mean()
# augmentation utils
class RandomApply(Module):
def __init__(self, fn, p):
super().__init__()
self.fn = fn
self.p = p
def forward(self, x):
if random.random() > self.p:
return x
return self.fn(x)
# MLP class for projector
class L2Norm(Module):
def forward(self, x, eps = 1e-6):
return l2norm(x, eps)
class MLP(Module):
def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
super().__init__()
layers = []
dims = (dim, *((hidden_size,) * (num_layers - 1)))
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
is_last = ind == (len(dims) - 1)
layers.extend([
nn.Linear(layer_dim_in, layer_dim_out),
nn.GELU() if not is_last else nn.Identity()
])
self.net = nn.Sequential(
*layers,
L2Norm(),
nn.Linear(hidden_size, dim_out)
)
def forward(self, x):
return self.net(x)
# wrapper
class NetWrapper(Module):
def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
super().__init__()
self.net = net
self.layer = layer
self.projector = None
self.projection_hidden_size = projection_hidden_size
self.projection_num_layers = projection_num_layers
self.output_dim = output_dim
self.hidden = {}
self.hook_registered = False
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
def _hook(self, _, input, output):
device = input[0].device
self.hidden[device] = output.flatten(1)
def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
@singleton('projector')
def _get_projector(self, hidden):
_, dim = hidden.shape
projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
return projector.to(hidden)
def get_embedding(self, x):
if self.layer == -1:
return self.net(x)
if not self.hook_registered:
self._register_hook()
self.hidden.clear()
_ = self.net(x)
hidden = self.hidden[x.device]
self.hidden.clear()
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
def forward(self, x, return_projection = True):
embed = self.get_embedding(x)
if not return_projection:
return embed
projector = self._get_projector(embed)
return projector(embed), embed
# main class
class LeJEPA(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer = -2,
projection_hidden_size = 256,
num_classes_K = 65336,
projection_layers = 4,
local_upper_crop_scale = 0.4,
global_lower_crop_scale = 0.5,
target_loss_weight = 1.,
sigreg_loss_weight = 1.,
sigreg_loss_kwargs = dict(
num_slices = 1024,
domain = (-5, 5),
num_knots = 17
),
augment_fn = None,
augment_fn2 = None
):
super().__init__()
self.net = net
# default BYOL augmentation
DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, DEFAULT_AUG)
# local and global crops
self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))
self.encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)
self.target_loss_weight = target_loss_weight
self.sigreg_loss_weight = sigreg_loss_weight
self.sigreg_loss_kwargs = sigreg_loss_kwargs
# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)
# send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
def forward(
self,
x,
return_embedding = False,
return_projection = True
):
if return_embedding:
return self.encoder(x, return_projection = return_projection)
image_one, image_two = self.augment1(x), self.augment2(x)
local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
local_images = torch.cat((local_image_one, local_image_two), dim = 0)
proj_locals, _ = self.encoder(local_images)
proj_local_one, proj_local_two = proj_locals.chunk(2, dim = 0)
with torch.no_grad():
global_images = torch.cat((global_image_one, global_image_two), dim = 0)
proj_globals, _ = self.encoder(global_images)
proj_global_one, proj_global_two = proj_globals.chunk(2, dim = 0)
# invariance loss
mse_loss = F.mse_loss(proj_local_one, proj_global_two) + F.mse_loss(proj_local_two, proj_global_one)
# sigreg loss
sreg_loss = sigreg_loss(proj_locals, **self.sigreg_loss_kwargs)
return mse_loss * self.target_loss_weight + sreg_loss * self.sigreg_loss_weight
# quick run
if __name__ == '__main__':
from vit_pytorch import ViT
model = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
)
learner = LeJEPA(
model,
image_size = 256,
hidden_layer = 'to_latent', # layer name where output is hidden dimension
projection_hidden_size = 256, # projector network hidden dimension
projection_layers = 4, # number of layers in projection network
num_classes_K = 65336, # output dimension
target_loss_weight = 1.0,
sigreg_loss_weight = 1.0
)
opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)
images = torch.randn(8, 3, 256, 256)
loss = learner(images)
opt.zero_grad()
loss.backward()
opt.step()
print('loss:', loss.item())

View File

@@ -182,7 +182,7 @@ class LeViT(nn.Module):
def forward(self, img):
x = self.conv_embedding(img)
x = self.backbone(x)
x = self.backbone(x)
x = self.pool(x)

View File

@@ -52,7 +52,7 @@ class MAE(nn.Module):
if self.encoder.pool == "cls":
tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
elif self.encoder.pool == "mean":
tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)
tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)
# calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked
@@ -87,7 +87,7 @@ class MAE(nn.Module):
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
# concat the masked tokens to the decoder tokens and attend with decoder
decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
decoder_tokens[batch_range, masked_indices] = mask_tokens

View File

@@ -77,7 +77,7 @@ class Dropsample(nn.Module):
def __init__(self, prob = 0):
super().__init__()
self.prob = prob
def forward(self, x):
device = x.device

View File

@@ -72,7 +72,7 @@ class Dropsample(Module):
def __init__(self, prob = 0):
super().__init__()
self.prob = prob
def forward(self, x):
device = x.device

View File

@@ -1,243 +1,243 @@
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Reduce
# helpers
def conv_1x1_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.SiLU()
)
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
return nn.Sequential(
nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.SiLU()
)
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b p h n d -> b p n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
"""Transformer block described in ViT.
Paper: https://arxiv.org/abs/2010.11929
Based on: https://github.com/lucidrains/vit-pytorch
"""
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads, dim_head, dropout),
FeedForward(dim, mlp_dim, dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class MV2Block(nn.Module):
"""MV2 block described in MobileNetV2.
Paper: https://arxiv.org/pdf/1801.04381
Based on: https://github.com/tonylins/pytorch-mobilenet-v2
"""
def __init__(self, inp, oup, stride=1, expansion=4):
super().__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(inp * expansion)
self.use_res_connect = self.stride == 1 and inp == oup
if expansion == 1:
self.conv = nn.Sequential(
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
else:
self.conv = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
def forward(self, x):
out = self.conv(x)
if self.use_res_connect:
out = out + x
return out
class MobileViTBlock(nn.Module):
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
super().__init__()
self.ph, self.pw = patch_size
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
self.conv2 = conv_1x1_bn(channel, dim)
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
self.conv3 = conv_1x1_bn(dim, channel)
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
def forward(self, x):
y = x.clone()
# Local representations
x = self.conv1(x)
x = self.conv2(x)
# Global representations
_, _, h, w = x.shape
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
# Fusion
x = self.conv3(x)
x = torch.cat((x, y), 1)
x = self.conv4(x)
return x
class MobileViT(nn.Module):
"""MobileViT.
Paper: https://arxiv.org/abs/2110.02178
Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
"""
def __init__(
self,
image_size,
dims,
channels,
num_classes,
expansion=4,
kernel_size=3,
patch_size=(2, 2),
depths=(2, 4, 3)
):
super().__init__()
assert len(dims) == 3, 'dims must be a tuple of 3'
assert len(depths) == 3, 'depths must be a tuple of 3'
ih, iw = image_size
ph, pw = patch_size
assert ih % ph == 0 and iw % pw == 0
init_dim, *_, last_dim = channels
self.conv1 = conv_nxn_bn(3, init_dim, stride=2)
self.stem = nn.ModuleList([])
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.trunk = nn.ModuleList([])
self.trunk.append(nn.ModuleList([
MV2Block(channels[3], channels[4], 2, expansion),
MobileViTBlock(dims[0], depths[0], channels[5],
kernel_size, patch_size, int(dims[0] * 2))
]))
self.trunk.append(nn.ModuleList([
MV2Block(channels[5], channels[6], 2, expansion),
MobileViTBlock(dims[1], depths[1], channels[7],
kernel_size, patch_size, int(dims[1] * 4))
]))
self.trunk.append(nn.ModuleList([
MV2Block(channels[7], channels[8], 2, expansion),
MobileViTBlock(dims[2], depths[2], channels[9],
kernel_size, patch_size, int(dims[2] * 4))
]))
self.to_logits = nn.Sequential(
conv_1x1_bn(channels[-2], last_dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(channels[-1], num_classes, bias=False)
)
def forward(self, x):
x = self.conv1(x)
for conv in self.stem:
x = conv(x)
for conv, attn in self.trunk:
x = conv(x)
x = attn(x)
return self.to_logits(x)
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Reduce
# helpers
def conv_1x1_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.SiLU()
)
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
return nn.Sequential(
nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.SiLU()
)
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b p h n d -> b p n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
"""Transformer block described in ViT.
Paper: https://arxiv.org/abs/2010.11929
Based on: https://github.com/lucidrains/vit-pytorch
"""
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads, dim_head, dropout),
FeedForward(dim, mlp_dim, dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class MV2Block(nn.Module):
"""MV2 block described in MobileNetV2.
Paper: https://arxiv.org/pdf/1801.04381
Based on: https://github.com/tonylins/pytorch-mobilenet-v2
"""
def __init__(self, inp, oup, stride=1, expansion=4):
super().__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(inp * expansion)
self.use_res_connect = self.stride == 1 and inp == oup
if expansion == 1:
self.conv = nn.Sequential(
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
else:
self.conv = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
def forward(self, x):
out = self.conv(x)
if self.use_res_connect:
out = out + x
return out
class MobileViTBlock(nn.Module):
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
super().__init__()
self.ph, self.pw = patch_size
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
self.conv2 = conv_1x1_bn(channel, dim)
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
self.conv3 = conv_1x1_bn(dim, channel)
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
def forward(self, x):
y = x.clone()
# Local representations
x = self.conv1(x)
x = self.conv2(x)
# Global representations
_, _, h, w = x.shape
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
# Fusion
x = self.conv3(x)
x = torch.cat((x, y), 1)
x = self.conv4(x)
return x
class MobileViT(nn.Module):
"""MobileViT.
Paper: https://arxiv.org/abs/2110.02178
Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
"""
def __init__(
self,
image_size,
dims,
channels,
num_classes,
expansion=4,
kernel_size=3,
patch_size=(2, 2),
depths=(2, 4, 3)
):
super().__init__()
assert len(dims) == 3, 'dims must be a tuple of 3'
assert len(depths) == 3, 'depths must be a tuple of 3'
ih, iw = image_size
ph, pw = patch_size
assert ih % ph == 0 and iw % pw == 0
init_dim, *_, last_dim = channels
self.conv1 = conv_nxn_bn(3, init_dim, stride=2)
self.stem = nn.ModuleList([])
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.trunk = nn.ModuleList([])
self.trunk.append(nn.ModuleList([
MV2Block(channels[3], channels[4], 2, expansion),
MobileViTBlock(dims[0], depths[0], channels[5],
kernel_size, patch_size, int(dims[0] * 2))
]))
self.trunk.append(nn.ModuleList([
MV2Block(channels[5], channels[6], 2, expansion),
MobileViTBlock(dims[1], depths[1], channels[7],
kernel_size, patch_size, int(dims[1] * 4))
]))
self.trunk.append(nn.ModuleList([
MV2Block(channels[7], channels[8], 2, expansion),
MobileViTBlock(dims[2], depths[2], channels[9],
kernel_size, patch_size, int(dims[2] * 4))
]))
self.to_logits = nn.Sequential(
conv_1x1_bn(channels[-2], last_dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(channels[-1], num_classes, bias=False)
)
def forward(self, x):
x = self.conv1(x)
for conv in self.stem:
x = conv(x)
for conv, attn in self.trunk:
x = conv(x)
x = attn(x)
return self.to_logits(x)

View File

@@ -110,7 +110,7 @@ class ViT(nn.Module):
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
@@ -178,7 +178,7 @@ class MP3(nn.Module):
attended_tokens = self.vit.transformer(tokens, tokens_unmasked)
logits = rearrange(self.mlp_head(attended_tokens), 'b n d -> (b n) d')
# Define labels
labels = repeat(torch.arange(num_patches, device = device), 'n -> (b n)', b = batch)
loss = F.cross_entropy(logits, labels)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from functools import partial
from functools import partial, lru_cache
from typing import List
import torch
@@ -9,7 +9,6 @@ from torch import nn, Tensor
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
@@ -28,6 +27,12 @@ def pair(t):
def divisible_by(numer, denom):
return (numer % denom) == 0
@lru_cache(maxsize=128)
def posemb_grid(ph, pw, device):
h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
w_idx = torch.arange(pw, device=device).repeat(ph)
return torch.stack([h_idx, w_idx], dim=-1)
# auto grouping images
def group_images_by_max_seq_len(
@@ -117,8 +122,7 @@ class Attention(nn.Module):
self.q_norm = RMSNorm(heads, dim_head)
self.k_norm = RMSNorm(heads, dim_head)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.dropout_p = dropout
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -145,19 +149,22 @@ class Attention(nn.Module):
q = self.q_norm(q)
k = self.k_norm(k)
dots = torch.matmul(q, k.transpose(-1, -2))
# combine masks if both exist
if exists(mask) or exists(attn_mask):
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
if exists(mask) and exists(attn_mask):
attn_mask = mask & attn_mask
elif exists(mask):
attn_mask = mask
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = attn_mask,
dropout_p = self.dropout_p if self.training else 0.,
scale = 1. # RMSNorm already includes sqrt(dim) scaling
)
if exists(attn_mask):
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -281,42 +288,41 @@ class NaViT(nn.Module):
for images in batched_images:
num_images.append(len(images))
sequences = []
positions = []
image_ids = torch.empty((0,), device = device, dtype = torch.long)
for image_id, image in enumerate(images):
assert image.ndim ==3 and image.shape[0] == c
# compute patch dimensions for all images
patch_dims = []
for image in images:
assert image.ndim == 3 and image.shape[0] == c
image_dims = image.shape[-2:]
assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
patch_dims.append((image_dims[0] // p, image_dims[1] // p))
ph, pw = map(lambda dim: dim // p, image_dims)
# extract patches for all images
sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images]
pos = torch.stack(torch.meshgrid((
arange(ph),
arange(pw)
), indexing = 'ij'), dim = -1)
# compute positions - uses lru_cache to avoid redundant computation across forward passes
positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims]
pos = rearrange(pos, 'h w c -> (h w) c')
seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
seq_len = seq.shape[-2]
if has_token_dropout:
# handle token dropout
if has_token_dropout:
for i, (seq, pos) in enumerate(zip(sequences, positions)):
image_dims = images[i].shape[-2:]
token_dropout = self.calc_token_dropout(*image_dims)
seq_len = seq.shape[0]
num_keep = max(1, int(seq_len * (1 - token_dropout)))
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
keep_indices = torch.randn((seq_len,), device=device).topk(num_keep, dim=-1).indices
sequences[i] = seq[keep_indices]
positions[i] = pos[keep_indices]
seq = seq[keep_indices]
pos = pos[keep_indices]
image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
sequences.append(seq)
positions.append(pos)
# build image_ids efficiently using repeat_interleave
patch_counts = [seq.shape[0] for seq in sequences]
image_ids = torch.repeat_interleave(
arange(len(images)),
torch.tensor(patch_counts, device=device)
)
batched_image_ids.append(image_ids)
batched_sequences.append(torch.cat(sequences, dim = 0))
batched_positions.append(torch.cat(positions, dim = 0))
batched_sequences.append(torch.cat(sequences, dim=0))
batched_positions.append(torch.cat(positions, dim=0))
# derive key padding mask
@@ -337,11 +343,11 @@ class NaViT(nn.Module):
# need to know how many images for final attention pooling
num_images = torch.tensor(num_images, device = device, dtype = torch.long)
num_images = torch.tensor(num_images, device = device, dtype = torch.long)
# to patches
x = self.to_patch_embedding(patches)
x = self.to_patch_embedding(patches)
# factorized 2d absolute positional embedding

View File

@@ -64,7 +64,7 @@ class Attention(Module):
self.to_out = nn.Linear(dim_inner, dim, bias = False)
def forward(
self,
self,
x,
context: Tensor | None = None
):

View File

@@ -64,7 +64,7 @@ class Attention(Module):
self.to_out = nn.Linear(dim_inner, dim, bias = False)
def forward(
self,
self,
x,
context: Tensor | None = None
):

View File

@@ -154,7 +154,7 @@ class PiT(nn.Module):
for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
not_last = ind < (len(depth) - 1)
layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))
if not_last:

View File

@@ -146,7 +146,7 @@ class R2LTransformer(nn.Module):
region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')
# calculate local relative positional bias
h_range = torch.arange(window_size_h, device = device)
w_range = torch.arange(window_size_w, device = device)

View File

@@ -187,7 +187,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w)
# add LIM output
# add LIM output
out = out + local_out

View File

@@ -26,7 +26,7 @@ def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
omega = 1. / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)
@@ -57,7 +57,7 @@ class Attend(nn.Module):
config = self.cuda_config if q.is_cuda else self.cpu_config
# flash attention - https://arxiv.org/abs/2205.14135
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(q, k, v)

View File

@@ -34,7 +34,7 @@ def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
z = z.flatten()[:, None] * omega[None, :]
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
@@ -52,7 +52,7 @@ class Attend(Module):
def flash_attn(self, q, k, v):
# flash attention - https://arxiv.org/abs/2205.14135
with torch.backends.cuda.sdp_kernel(**self.config._asdict()):
out = F.scaled_dot_product_attention(q, k, v)

View File

@@ -35,7 +35,7 @@ def FeedForward(dim, hidden_dim):
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):

View File

@@ -98,7 +98,7 @@ class SimpleViT(nn.Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

View File

@@ -26,7 +26,7 @@ def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
z = z.flatten()[:, None] * omega[None, :]
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)

View File

@@ -0,0 +1,230 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def last(arr):
return arr[-1]
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def divisible_by(num, den):
return (num % den) == 0
def posemb_sincos_2d(h, w, dim, temperature = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing = 'ij')
assert divisible_by(dim, 4), 'feature dimension must be multiple of 4 for sincos emb'
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, context = None):
x = self.norm(x)
context = default(context, x)
q = self.to_q(x)
k, v = self.to_kv(context).chunk(2, dim = -1)
q, k, v = tuple(rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class AttentionPool(Module):
def __init__(self, dim, heads = 8, dim_head = 64, use_learned_query = True):
super().__init__()
self.use_learned_query = use_learned_query
self.norm_context = nn.LayerNorm(dim)
self.attn = Attention(dim, heads = heads, dim_head = dim_head)
if use_learned_query:
self.query = nn.Parameter(torch.randn(dim))
def forward(self, context, query = None):
batch = context.shape[0]
context = self.norm_context(context)
if self.use_learned_query:
q = repeat(self.query, 'd -> b 1 d', b = batch)
else:
q = query
return self.attn(q, context = context)
class AttentionResidual(Module):
"""
replaces the standard residual connection.
pools from a growing history of all previous outputs via attention,
then passes the result through the wrapped module (attn or ff).
the output is appended back to history (in-place list mutation).
"""
def __init__(self, fn, dim, heads = 8, dim_head = 64, use_learned_query = True):
super().__init__()
self.fn = fn
self.attn_pool = AttentionPool(dim, heads = heads, dim_head = dim_head, use_learned_query = use_learned_query)
def forward(self, history):
context = torch.stack(history, dim = 2)
b, n, l, d = context.shape
context = rearrange(context, 'b n l d -> (b n) l d')
last_out = last(history)
query = rearrange(last_out, 'b n d -> (b n) 1 d')
pooled = self.attn_pool(context, query = query)
pooled = rearrange(pooled, '(b n) 1 d -> b n d', b = b, n = n)
out = self.fn(pooled)
# mutate history for subsequent layers
history.append(out)
return history
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_learned_query = True):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
AttentionResidual(Attention(dim, heads = heads, dim_head = dim_head), dim, heads = heads, dim_head = dim_head, use_learned_query = use_learned_query),
AttentionResidual(FeedForward(dim, mlp_dim), dim, heads = heads, dim_head = dim_head, use_learned_query = use_learned_query)
]))
self.final_attn_pool = AttentionResidual(nn.LayerNorm(dim), dim, heads = heads, dim_head = dim_head, use_learned_query = use_learned_query)
def forward(self, tokens):
history = [tokens]
for attn_res, ff_res in self.layers:
history = attn_res(history)
history = ff_res(history)
history = self.final_attn_pool(history)
return last(history)
class SimpleViTAttnResidual(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
channels = 3,
dim_head = 64,
use_learned_query = True
):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_learned_query = use_learned_query)
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
device, dtype = img.device, img.dtype
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype = dtype)
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
if __name__ == '__main__':
for use_learned_query in (True, False):
v = SimpleViTAttnResidual(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
use_learned_query = use_learned_query
)
img = torch.randn(2, 3, 256, 256)
preds = v(img)
assert preds.shape == (2, 1000)

View File

@@ -186,7 +186,7 @@ class SimpleViT(nn.Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams)

View File

@@ -18,7 +18,7 @@ def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
omega = 1. / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)

View File

@@ -119,7 +119,7 @@ class SimpleViT(nn.Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

View File

@@ -105,7 +105,7 @@ class SimpleViT(nn.Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

View File

@@ -120,7 +120,7 @@ class SimpleViT(Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

View File

@@ -41,7 +41,7 @@ def posemb_sincos_2d(
omega = temperature ** -omega
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
pe = pe.type(dtype)
@@ -120,6 +120,12 @@ class Attention(Module):
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
self.to_out_gates = nn.Sequential(
nn.Linear(dim, heads),
Rearrange('b ... h -> b h ... 1'),
nn.Sigmoid()
)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
@@ -150,6 +156,9 @@ class Attention(Module):
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = out * self.to_out_gates(x)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -433,7 +442,8 @@ class VAAT(Module):
self_attn_heads = 4,
self_attn_dim_head = 32,
ast_layer_indices: tuple[int, ...] | None = None,
vit_layer_indices: tuple[int, ...] | None = None
vit_layer_indices: tuple[int, ...] | None = None,
num_advantage_bins = 0
):
super().__init__()
@@ -471,7 +481,7 @@ class VAAT(Module):
assert len(ast_layer_indices) == depth, f'number of ast layer indices {len(ast_layer_indices)} does not much the VAAT depth {depth}'
self.register_buffer('ast_layer_indices', tensor(vit_layer_indices), persistent = False)
self.register_buffer('ast_layer_indices', tensor(ast_layer_indices), persistent = False)
# handle maybe multiple frames
@@ -502,6 +512,14 @@ class VAAT(Module):
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
# handle maybe advantage conditioning
self.has_advantages = num_advantage_bins > 0
self.num_advantage_bins = num_advantage_bins
if self.has_advantages:
self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)
self.layers = ModuleList([])
for _ in range(depth):
@@ -531,14 +549,15 @@ class VAAT(Module):
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
audio_or_spec, # (b v? t) | (b v?f t) - batch, audio len | batch, spec freq, time
*,
extra = None, # (b d) - batch, dim extra
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
advantages = None,# (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False,
freeze_ast = False
):
batch = video_or_image.shape[0]
batch, device = video_or_image.shape[0], video_or_image.device
return_loss = exists(actions)
# handle some various input dimensions
@@ -646,53 +665,66 @@ class VAAT(Module):
audio_context = rearrange(audio_hiddens, 'l b v n d -> l b (v n) d')
# get main action tokens and maybe append extra
# main action tokens
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)
has_extra = exists(extra)
# maybe advantage tokens
if has_extra:
assert self.accept_extra_token
empty_token = action_tokens[:, 0:0]
extra_token = self.to_extra_token(extra)
maybe_advantage_embed = empty_token
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
if self.has_advantages and exists(advantages):
if isinstance(advantages, int):
advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)
maybe_advantage_embed = self.advantage_emb(advantages + 1)
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
register_tokens = empty_token
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
if exists(self.register_tokens):
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
# cross attention
# extra
hiddens = [action_tokens]
maybe_extra_embed = empty_token
has_extra = exists(extra)
if has_extra:
assert self.accept_extra_token
maybe_extra_embed = self.to_extra_token(extra)
# pack all tokens for attention
tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')
# transformer
hiddens = [tokens]
for (maybe_film, maybe_self_attn, image_cross_attn, audio_cross_attn, ff), image_layer_context, audio_layer_context in zip(self.layers, image_context, audio_context):
if exists(tasks):
action_tokens = maybe_film(action_tokens, task_emb)
if exists(maybe_film) and exists(tasks):
tokens = maybe_film(tokens, task_emb)
action_tokens = image_cross_attn(action_tokens, image_layer_context) + action_tokens
tokens = image_cross_attn(tokens, image_layer_context) + tokens
action_tokens = audio_cross_attn(action_tokens, audio_layer_context) + action_tokens
tokens = audio_cross_attn(tokens, audio_layer_context) + tokens
if exists(maybe_self_attn):
action_tokens = maybe_self_attn(action_tokens) + action_tokens
tokens = maybe_self_attn(tokens) + tokens
action_tokens = ff(action_tokens) + action_tokens
tokens = ff(tokens) + tokens
hiddens.append(action_tokens)
hiddens.append(tokens)
# unpack registers
# unpack register, advantage, action, and extra tokens
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
# maybe unpack extra
if has_extra:
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')
# norm and prediction
@@ -735,43 +767,51 @@ if __name__ == '__main__':
mlp_dim = 384 * 4
)
vaat = VAAT(
vit,
ast,
dim = 512,
depth = 9,
heads = 8,
dim_head = 64,
mlp_dim = 2048,
dim_action = 20,
action_chunk_len = 7,
time_seq_len = 4,
num_image_views = 2,
num_audio_views = 2,
num_tasks = 4,
add_self_attn = True,
dim_extra_token = 33, # extra token with some variable dimension
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
0, 0, 1, 1, 2, 2, 3, 3, 4
),
ast_layer_indices = (
1, 1, 1, 2, 2, 2, 3, 3, 3
for num_adv_bins in (0, 2, 10):
vaat = VAAT(
vit,
ast,
dim = 512,
depth = 9,
heads = 8,
dim_head = 64,
mlp_dim = 2048,
dim_action = 20,
action_chunk_len = 7,
time_seq_len = 4,
num_image_views = 2,
num_audio_views = 2,
num_tasks = 4,
num_advantage_bins = num_adv_bins,
add_self_attn = True,
dim_extra_token = 33, # extra token with some variable dimension
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
0, 0, 1, 1, 2, 2, 3, 3, 4
),
ast_layer_indices = (
1, 1, 1, 2, 2, 2, 3, 3, 3
)
)
)
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
audio = torch.randn(2, 2, 14_100 * 5)
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
audio = torch.randn(2, 2, 14_100 * 5)
tasks = torch.randint(0, 4, (2,))
extra = torch.randn(2, 33) # extra internal state
tasks = torch.randint(0, 4, (2,))
extra = torch.randn(2, 33) # extra internal state
actions = torch.randn(2, 7, 20) # actions for learning
# advantage conditioning
loss = vaat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
advantages = None
if num_adv_bins > 0:
advantages = torch.randint(-1, num_adv_bins, (2,))
# after much training
actions = torch.randn(2, 7, 20) # actions for learning
pred_actions, hiddens = vaat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
loss = vaat(images, audio, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
assert pred_actions.shape == (2, 7, 20)
# after much training
pred_actions, hiddens = vaat(images, audio, advantages = advantages, tasks = tasks, extra = extra, return_hiddens = True)
assert pred_actions.shape == (2, 7, 20)

View File

@@ -92,6 +92,12 @@ class Attention(Module):
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
self.to_out_gates = nn.Sequential(
nn.Linear(dim, heads),
Rearrange('b ... h -> b h ... 1'),
nn.Sigmoid()
)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
@@ -122,6 +128,8 @@ class Attention(Module):
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = out * self.to_out_gates(x) # https://arxiv.org/abs/2505.06708
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -270,7 +278,8 @@ class VAT(Module):
add_self_attn = True, # in the paper, they didn't have any ways for the action token to exchange information with the extra token, so we'll just add it as an option
self_attn_heads = 4,
self_attn_dim_head = 32,
vit_layer_indices: tuple[int, ...] | None = None
vit_layer_indices: tuple[int, ...] | None = None,
num_advantage_bins = 0
):
super().__init__()
@@ -316,6 +325,14 @@ class VAT(Module):
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
# handle maybe advantage conditioning
self.has_advantages = num_advantage_bins > 0
self.num_advantage_bins = num_advantage_bins
if self.has_advantages:
self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)
self.layers = ModuleList([])
for _ in range(depth):
@@ -343,13 +360,14 @@ class VAT(Module):
self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
*,
extra = None, # (b d) - batch, dim extra
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
advantages = None,# (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False
):
batch = video_or_image.shape[0]
batch, device = video_or_image.shape[0], video_or_image.device
return_loss = exists(actions)
# handle some various input dimensions
@@ -415,51 +433,64 @@ class VAT(Module):
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
# get main action tokens and maybe append extra
# main action tokens
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)
has_extra = exists(extra)
# maybe advantage tokens
if has_extra:
assert self.accept_extra_token
empty_token = action_tokens[:, 0:0]
extra_token = self.to_extra_token(extra)
maybe_advantage_embed = empty_token
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
if self.has_advantages and exists(advantages):
if isinstance(advantages, int):
advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)
maybe_advantage_embed = self.advantage_emb(advantages + 1)
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
register_tokens = empty_token
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
if exists(self.register_tokens):
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
# cross attention
# extra
hiddens = [action_tokens]
maybe_extra_embed = empty_token
has_extra = exists(extra)
if has_extra:
assert self.accept_extra_token
maybe_extra_embed = self.to_extra_token(extra)
# pack all tokens for attention
tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')
# transformer
hiddens = [tokens]
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
if exists(tasks):
action_tokens = maybe_film(action_tokens, task_emb)
if exists(maybe_film) and exists(tasks):
tokens = maybe_film(tokens, task_emb)
action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
tokens = cross_attn(tokens, layer_context) + tokens
if exists(maybe_self_attn):
action_tokens = maybe_self_attn(action_tokens) + action_tokens
tokens = maybe_self_attn(tokens) + tokens
action_tokens = ff(action_tokens) + action_tokens
tokens = ff(tokens) + tokens
hiddens.append(action_tokens)
hiddens.append(tokens)
# unpack registers
# unpack register, advantage, action, and extra tokens
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
# maybe unpack extra
if has_extra:
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')
# norm and prediction
@@ -493,36 +524,44 @@ if __name__ == '__main__':
mlp_dim = 1024
)
vat = VAT(
vit,
dim = 512,
depth = 9,
heads = 8,
dim_head = 64,
mlp_dim = 2048,
dim_action = 20,
action_chunk_len = 7,
time_seq_len = 4,
num_views = 2,
num_tasks = 4,
add_self_attn = True,
dim_extra_token = 33, # extra token with some variable dimension
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
0, 0, 1, 1, 2, 2, 3, 3, 4
for num_adv_bins in (0, 2, 10):
vat = VAT(
vit,
dim = 512,
depth = 9,
heads = 8,
dim_head = 64,
mlp_dim = 2048,
dim_action = 20,
action_chunk_len = 7,
time_seq_len = 4,
num_views = 2,
num_tasks = 4,
num_advantage_bins = num_adv_bins,
add_self_attn = True,
dim_extra_token = 33, # extra token with some variable dimension
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
0, 0, 1, 1, 2, 2, 3, 3, 4
)
)
)
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
tasks = torch.randint(0, 4, (2,))
extra = torch.randn(2, 33) # extra internal state
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
tasks = torch.randint(0, 4, (2,))
extra = torch.randn(2, 33) # extra internal state
actions = torch.randn(2, 7, 20) # actions for learning
# advantage conditioning
loss = vat(images, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
advantages = None
if num_adv_bins > 0:
advantages = torch.randint(-1, num_adv_bins, (2,))
# after much training
actions = torch.randn(2, 7, 20) # actions for learning
pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True)
loss = vat(images, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
assert pred_actions.shape == (2, 7, 20)
# after much training
pred_actions, hiddens = vat(images, advantages = advantages, tasks = tasks, extra = extra, return_hiddens = True)
assert pred_actions.shape == (2, 7, 20)

521
vit_pytorch/vat_siglip.py Normal file
View File

@@ -0,0 +1,521 @@
from __future__ import annotations
from contextlib import nullcontext
from pathlib import Path
import torch
import torch.nn.functional as F
from torch import nn, cat, stack, tensor, einsum
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# attention
class Attention(Module):
def __init__(
self,
dim,
dim_context = None,
heads = 8,
dim_head = 64,
dropout = 0.,
norm_eps = 1e-6,
gate_attn = False
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.norm = nn.LayerNorm(dim, eps = norm_eps)
self.is_cross_attn = exists(dim_context)
dim_context = default(dim_context, dim)
self.norm_context = nn.LayerNorm(dim_context, eps = norm_eps) if self.is_cross_attn else None
self.to_q = nn.Linear(dim, inner_dim)
self.to_kv = nn.Linear(dim_context, inner_dim * 2)
self.to_out_gates = nn.Sequential(
nn.Linear(dim, heads),
Rearrange('b ... h -> b h ... 1'),
nn.Sigmoid()
) if gate_attn else None
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, context = None):
x = self.norm(x)
if self.is_cross_attn:
assert exists(context)
context = self.norm_context(context)
else:
context = x
q = self.to_q(x)
k, v = self.to_kv(context).chunk(2, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
if exists(self.to_out_gates):
out = out * self.to_out_gates(x) # https://arxiv.org/abs/2505.06708
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
def FeedForward(
dim,
dim_inner,
norm_eps = 1e-6
):
return nn.Sequential(
nn.LayerNorm(dim, eps = norm_eps),
nn.Linear(dim, dim_inner),
nn.GELU(approximate = 'tanh'),
nn.Linear(dim_inner, dim)
)
class SigLIP(Module):
def __init__(
self,
image_size = 224,
patch_size = 14,
dim = 1152,
depth = 27,
heads = 16,
mlp_dim = 4304,
norm_eps = 1e-6
):
super().__init__()
self.dim = dim
self.depth = depth
num_patches = (image_size // patch_size) ** 2
dim_head = dim // heads
self.to_patch_embed = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear(patch_size * patch_size * 3, dim)
)
self.pos_embed = nn.Parameter(torch.randn(num_patches, dim))
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, norm_eps = norm_eps),
FeedForward(dim = dim, dim_inner = mlp_dim, norm_eps = norm_eps)
]))
self.norm = nn.LayerNorm(dim, eps = norm_eps)
def forward(self, x, return_hiddens = False):
x = self.to_patch_embed(x)
num_patches = x.shape[1]
x = x + self.pos_embed[:num_patches]
hiddens = []
for attn, ff in self.layers:
hiddens.append(x)
x = attn(x) + x
x = ff(x) + x
out = self.norm(x)
if not return_hiddens:
return out
return out, stack(hiddens)
class FiLM(Module):
def __init__(self, dim):
super().__init__()
proj = nn.Linear(dim, dim * 2)
self.to_gamma_beta = nn.Sequential(
proj,
Rearrange('b (two d) -> two b 1 d', two = 2)
)
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)
def forward(self, tokens, cond):
gamma, beta = self.to_gamma_beta(cond)
return tokens * gamma + beta
class SigLIPVAT(Module):
def __init__(
self,
*,
dim = 512,
depth = 27,
heads = 8,
dim_head = 64,
dim_action = 32,
mlp_dim = 2048,
num_views = 1,
num_tasks = None,
dim_extra_token = None,
num_register_tokens = 4,
action_chunk_len = 50,
time_seq_len = 1,
dropout = 0.,
add_self_attn = True,
self_attn_heads = 4,
self_attn_dim_head = 32,
vit_layer_indices: tuple[int, ...] | None = None,
num_advantage_bins = 0,
siglip_image_size = 224,
siglip_patch_size = 14,
siglip_dim = 1152,
siglip_depth = 27,
siglip_heads = 16,
siglip_mlp_dim = 4304,
siglip_norm_eps = 1e-6,
):
super().__init__()
self.vit = SigLIP(
image_size = siglip_image_size,
patch_size = siglip_patch_size,
dim = siglip_dim,
depth = siglip_depth,
heads = siglip_heads,
mlp_dim = siglip_mlp_dim,
norm_eps = siglip_norm_eps
)
vit_dim = siglip_dim
self.vit_dim = vit_dim
vit_layer_indices = default(vit_layer_indices, tuple(range(depth)))
assert len(vit_layer_indices) == depth, f'number of vit layer indices {len(vit_layer_indices)} does not much the VAT depth {depth}'
self.register_buffer('layer_indices', tensor(vit_layer_indices), persistent = False)
# handle maybe multiple frames
is_video = time_seq_len > 1
self.is_video = is_video
self.time_seq_len = time_seq_len
self.time_pos_emb = nn.Parameter(torch.randn(time_seq_len, vit_dim) * 1e-2) if is_video else None
# maybe view embeddings
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
# handle maybe task conditioning
self.has_tasks = exists(num_tasks)
if self.has_tasks:
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
# register tokens
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
# to action tokens
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
# handle maybe advantage conditioning
self.has_advantages = num_advantage_bins > 0
self.num_advantage_bins = num_advantage_bins
if self.has_advantages:
self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)
self.layers = ModuleList([])
for _ in range(depth):
maybe_film = FiLM(dim = dim) if self.has_tasks else None
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
self.layers.append(ModuleList([
maybe_film,
maybe_self_attn,
Attention(dim = dim, dim_context = vit_dim, heads = heads, dim_head = dim_head, dropout = dropout, gate_attn = True),
FeedForward(dim = dim, dim_inner = mlp_dim)
]))
self.final_norm = nn.LayerNorm(dim)
self.to_pred_action = nn.Linear(dim, dim_action, bias = False)
# handle the extra token
self.accept_extra_token = exists(dim_extra_token)
if exists(dim_extra_token):
self.to_extra_token = nn.Linear(dim_extra_token, dim)
def load_siglip(
self,
repo_id = 'google/siglip-so400m-patch14-224',
folder = 'checkpoints/siglip'
):
folder = Path(folder)
if not folder.exists():
from huggingface_hub import snapshot_download
snapshot_download(
repo_id = repo_id,
local_dir = folder,
allow_patterns = ['config.json', 'model.safetensors']
)
from safetensors import safe_open
weights_path = folder / 'model.safetensors'
# Auto-detect prefix based on keys
with safe_open(weights_path, framework = 'pt') as f:
keys = f.keys()
vi_p = ''
if any(k.startswith('paligemma_with_expert.paligemma.model.vision_tower.vision_model') for k in keys):
vi_p = 'paligemma_with_expert.paligemma.model.vision_tower.vision_model.'
elif any(k.startswith('vision_model') for k in keys):
vi_p = 'vision_model.'
pz_state = self.vit.state_dict()
def copy_weight_bias(pz_prefix, vi_prefix):
pz_state[f'{pz_prefix}.weight'].copy_(f.get_tensor(f'{vi_prefix}.weight'))
pz_state[f'{pz_prefix}.bias'].copy_(f.get_tensor(f'{vi_prefix}.bias'))
# patch embedding
patch_weight = rearrange(f.get_tensor(f'{vi_p}embeddings.patch_embedding.weight'), 'd c h w -> d (h w c)')
pz_state['to_patch_embed.1.weight'].copy_(patch_weight)
pz_state['to_patch_embed.1.bias'].copy_(f.get_tensor(f'{vi_p}embeddings.patch_embedding.bias'))
# position embedding
pz_state['pos_embed'].copy_(f.get_tensor(f'{vi_p}embeddings.position_embedding.weight'))
# transformer layers
for i in range(self.vit.depth):
v_pi = f'{vi_p}encoder.layers.{i}'
v_pz = f'layers.{i}'
# attention
copy_weight_bias(f'{v_pz}.0.norm', f'{v_pi}.layer_norm1')
copy_weight_bias(f'{v_pz}.0.to_q', f'{v_pi}.self_attn.q_proj')
vk, vv = [f.get_tensor(f'{v_pi}.self_attn.{x}_proj.weight') for x in ('k', 'v')]
bk, bv = [f.get_tensor(f'{v_pi}.self_attn.{x}_proj.bias') for x in ('k', 'v')]
pz_state[f'{v_pz}.0.to_kv.weight'].copy_(cat((vk, vv), dim = 0))
pz_state[f'{v_pz}.0.to_kv.bias'].copy_(cat((bk, bv), dim = 0))
copy_weight_bias(f'{v_pz}.0.to_out.0', f'{v_pi}.self_attn.out_proj')
# feedforward
copy_weight_bias(f'{v_pz}.1.0', f'{v_pi}.layer_norm2')
copy_weight_bias(f'{v_pz}.1.1', f'{v_pi}.mlp.fc1')
copy_weight_bias(f'{v_pz}.1.3', f'{v_pi}.mlp.fc2')
# post-layernorm
copy_weight_bias('norm', f'{vi_p}post_layernorm')
self.vit.load_state_dict(pz_state)
print(f'Successfully loaded SigLIP weights from {repo_id}')
def forward(
self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
*,
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
advantages = None,# (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False
):
batch, device = video_or_image.shape[0], video_or_image.device
return_loss = exists(actions)
# handle some various input dimensions
if video_or_image.ndim == 4:
video_or_image = rearrange(video_or_image, 'b 1 c h w')
if video_or_image.ndim == 5:
video_or_image = rearrange(video_or_image, 'b v c h w -> b v c 1 h w')
assert video_or_image.shape[3] == self.time_seq_len
# to images
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
images, packed_shape = pack([images], '* c h w')
# get representation trajectory from vit
vit_forward_context = torch.no_grad if freeze_vit else nullcontext
with vit_forward_context():
embed, hiddens = self.vit(images, return_hiddens = True)
hiddens = cat((hiddens, embed[None, ...]))
# extract the hiddens needed for the action cross attention
hiddens = hiddens[self.layer_indices]
# pack temporarily for embedding
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers
# maybe add time embeddings
if self.is_video:
time_pos_emb = rearrange(self.time_pos_emb, 't d -> t 1 d')
hiddens = hiddens + time_pos_emb
# maybe view embeddings
if exists(self.view_emb):
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
hiddens = hiddens + view_emb
# maybe tasks
if exists(tasks):
task_emb = self.task_emb[tasks]
# cross from actions to representation trajectory
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
# main action tokens
action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)
# maybe advantage tokens
empty_token = action_tokens[:, 0:0]
maybe_advantage_embed = empty_token
if self.has_advantages and exists(advantages):
if isinstance(advantages, int):
advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)
maybe_advantage_embed = self.advantage_emb(advantages + 1)
# register tokens
register_tokens = empty_token
if exists(self.register_tokens):
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
# extra
maybe_extra_embed = empty_token
has_extra = exists(extra)
if has_extra:
maybe_extra_embed = self.to_extra_token(extra)
# pack all tokens for attention
tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')
# transformer
vat_hiddens = [tokens]
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
if exists(maybe_film) and exists(tasks):
tokens = maybe_film(tokens, task_emb)
tokens = cross_attn(tokens, layer_context) + tokens
if exists(maybe_self_attn):
tokens = maybe_self_attn(tokens) + tokens
tokens = ff(tokens) + tokens
vat_hiddens.append(tokens)
# unpack register, advantage, action, and extra tokens
maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')
# norm and prediction
action_tokens = self.final_norm(action_tokens)
pred_action = self.to_pred_action(action_tokens)
if not return_loss:
if not return_hiddens:
return pred_action
return pred_action, stack(vat_hiddens)
assert pred_action.shape[1] == actions.shape[1]
return F.l1_loss(pred_action, actions)
# quick test
if __name__ == '__main__':
for num_adv_bins in (0, 2, 10):
vat = SigLIPVAT(
num_tasks = 4,
dim_extra_token = 32,
time_seq_len = 2,
num_views = 2,
depth = 4,
num_advantage_bins = num_adv_bins,
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
0, 1, 26, 27
)
)
vat.load_siglip() # load siglip weights from hf
# inputs
images = torch.randn(1, 2, 3, 2, 224, 224) # (b, v, c, t, h, w)
tasks = torch.randint(0, 4, (1,))
extra = torch.randn(1, 32)
# advantage conditioning
advantages = None
if num_adv_bins > 0:
advantages = torch.randint(-1, num_adv_bins, (1,))
actions = torch.randn(1, 50, 32) # actions for learning
loss = vat(images, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
# after much training
pred_actions = vat(images, advantages = advantages, tasks = tasks, extra = extra)
assert pred_actions.shape == (1, 50, 32)

View File

@@ -113,7 +113,7 @@ class ViT(Module):
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
def forward(self, img):
batch = img.shape[0]
@@ -129,6 +129,9 @@ class ViT(Module):
x = self.transformer(x)
if self.mlp_head is None:
return x
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)

View File

@@ -31,7 +31,7 @@ class FeedForward(Module):
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
@@ -40,31 +40,31 @@ class Attention(Module):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -79,7 +79,7 @@ class Transformer(Module):
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
@@ -105,73 +105,73 @@ class ViTND(Module):
emb_dropout: float = 0.
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.ndim = ndim
self.pool = pool
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b ({join(dim_names)}) ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, x):
x = self.to_patch_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x[:, 1:].mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
if __name__ == '__main__':
model = ViTND(
ndim = 4,
input_shape = (8, 16, 32, 64),
@@ -185,7 +185,7 @@ if __name__ == '__main__':
dropout = 0.1,
emb_dropout = 0.1
)
occupancy_time = torch.randn(2, 3, 8, 16, 32, 64)
logits = model(occupancy_time)

353
vit_pytorch/vit_nd_pope.py Normal file
View File

@@ -0,0 +1,353 @@
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import pi, nn, arange, cat, stack, Tensor
from torch.nn import Module, ModuleList
from torch.amp import autocast
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
def join(arr, delimiter = ' '):
return delimiter.join(arr)
def ensure_tuple(t, length):
if isinstance(t, (tuple, list)):
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
return tuple(t)
return (t,) * length
# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/
# but using polar version instead
# Gopalakrishnan et al. https://arxiv.org/abs/2509.10534
def _phi(m: int) -> float:
x = 2.0
for _ in range(10):
x = (1 + x) ** (1.0 / (m + 1.0))
return x
def make_directions(n: int, d: int) -> Tensor:
g = _phi(d)
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
z = torch.fmod(i * alpha, 1.0)
directions = torch.erfinv(2.0 * z - 1.0)
directions = l2norm(directions)
return directions.float()
class GoldenGatePoPENd(Module):
def __init__(
self,
dim_pos: int,
heads: int,
dim_head: int,
min_freq: float = 1.0,
max_freq: float = 10000.0,
p_zero_freqs: float = 0.0, # proportion of frequencies set to 0
init_learned_bias_uniform = False
):
super().__init__()
n_freqs = dim_head
n_zero_freqs = round(p_zero_freqs * n_freqs)
omega = cat((
torch.zeros(n_zero_freqs),
min_freq * (max_freq / min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
))
directions = rearrange(
make_directions(heads * n_freqs, dim_pos),
'(h f) p -> h f p',
h = heads
)
omega_expanded = rearrange(omega, 'f -> f 1')
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
self.learned_bias = nn.Parameter(torch.zeros(heads, dim_head))
if init_learned_bias_uniform:
self.learned_bias.uniform_(-2. * pi, 0.)
@autocast('cuda', enabled = False)
def forward(self, pos):
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
# compute theta for each (batch, head, seq, freq)
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
bias = self.learned_bias.clamp(-2. * pi, 0.)
bias = rearrange(bias, 'h d -> h 1 d')
return theta, bias
@autocast('cuda', enabled = False)
def apply_polar_pos_emb(t, freqs):
orig_dtype = t.dtype
t = t.float()
t = F.softplus(t)
out = cat((t * freqs.cos(), t * freqs.sin()), dim = -1)
return out.type(orig_dtype)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, polar_pos_emb = None):
x = self.norm(x)
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
if exists(polar_pos_emb):
freqs, bias = polar_pos_emb
q = apply_polar_pos_emb(q, freqs)
k = apply_polar_pos_emb(k, freqs + bias)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., polar_emb = None):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.polar_emb = polar_emb
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, pos = None):
# pope embedding
polar_pos_emb = None
if exists(pos) and exists(self.polar_emb):
polar_pos_emb = self.polar_emb(pos)
# transformer layers
for attn, ff in self.layers:
x = attn(x, polar_pos_emb) + x
x = ff(x) + x
return self.norm(x)
class ViTND(Module):
def __init__(
self,
*,
ndim: int,
input_shape: int | tuple[int, ...],
patch_size: int | tuple[int, ...],
num_classes: int,
dim: int,
depth: int,
heads: int,
mlp_dim: int,
channels: int = 3,
dim_head: int = 64,
dropout: float = 0.,
emb_dropout: float = 0.,
pope_min_freq: float = 1.0,
pope_max_freq: float = 10000.0,
pope_p_zero_freqs: float = 0.0,
pope_init_learned_bias_uniform = False
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
self.ndim = ndim
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.dropout = nn.Dropout(emb_dropout)
# golden gate pope
self.polar_emb = GoldenGatePoPENd(
dim_pos = ndim,
heads = heads,
dim_head = dim_head,
min_freq = pope_min_freq,
max_freq = pope_max_freq,
p_zero_freqs = pope_p_zero_freqs,
init_learned_bias_uniform = pope_init_learned_bias_uniform
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, polar_emb = self.polar_emb)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def muon_parameters(self):
params = []
for m in self.modules():
if isinstance(m, Attention):
params.extend([
m.to_v.weight,
m.to_out[0].weight
])
elif isinstance(m, FeedForward):
params.extend([
m.net[1].weight,
m.net[-2].weight
])
return params
def forward(
self,
x,
return_embed = False
):
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
batch, *spatial_dims, _, device = *x.shape, x.device
# Generate position coordinates
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
grid = torch.meshgrid(*grids, indexing = 'ij')
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
# flatten spatial dimensions for attention with nd rotary
pos = repeat(pos, '... p -> b (...) p', b = batch)
x, packed_shape = pack([x], 'b * d')
x = self.dropout(x)
embed = self.transformer(x, pos)
# return the embed with reconstituted patch shape
if return_embed:
embed, = unpack(embed, packed_shape, 'b * d')
return embed
# pooling to logits
pooled = reduce(embed, 'b n d -> b d', 'mean')
pooled = self.to_latent(pooled)
return self.mlp_head(pooled)
if __name__ == '__main__':
model = ViTND(
ndim = 5,
input_shape = (4, 8, 16, 32, 64),
patch_size = (2, 2, 4, 4, 8),
num_classes = 1000,
dim = 512,
depth = 6,
heads = 8,
mlp_dim = 2048,
channels = 3,
dropout = 0.1,
emb_dropout = 0.1
)
data = torch.randn(3, 3, 4, 8, 16, 32, 64)
logits = model(data)
embed = model(data, return_embed = True)
assert embed.shape == (3, 2, 4, 4, 8, 8, 512)

View File

@@ -75,23 +75,23 @@ class GoldenGateRoPENd(Module):
# input shape: (b, h, n, d) where d = head_dim
# pos shape: (b, n, p) where p = pos_dim
# self.freqs shape: (h, f, p) where f = d // 2
x, y = input.float().chunk(2, dim = -1) # both (b, h, n, f)
# Expand dimensions for broadcasting
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
# Compute theta for each (batch, head, seq, freq)
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)
# Apply rotation
x_out = x * cos_theta - y * sin_theta
y_out = x * sin_theta + y * cos_theta
output = cat((x_out, y_out), dim=-1)
return output.type_as(input)
@@ -108,7 +108,7 @@ class FeedForward(Module):
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
@@ -117,15 +117,15 @@ class Attention(Module):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.rotary_emb = rotary_emb
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
@@ -133,24 +133,24 @@ class Attention(Module):
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, pos = None):
x = self.norm(x)
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
# Apply rotary embeddings if available
if exists(self.rotary_emb):
assert exists(pos)
q = self.rotary_emb(q, pos)
k = self.rotary_emb(k, pos)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -165,7 +165,7 @@ class Transformer(Module):
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, rotary_emb = rotary_emb),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, pos = None):
for attn, ff in self.layers:
x = attn(x, pos) + x
@@ -193,45 +193,45 @@ class ViTND(Module):
rope_p_zero_freqs: float = 0.0
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
self.ndim = ndim
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.dropout = nn.Dropout(emb_dropout)
# Create rotary embeddings
self.rotary_emb = GoldenGateRoPENd(
dim_pos = ndim,
@@ -241,12 +241,12 @@ class ViTND(Module):
rope_max_freq = rope_max_freq,
rope_p_zero_freqs = rope_p_zero_freqs
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, rotary_emb = self.rotary_emb)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def muon_parameters(self):
params = []
@@ -270,9 +270,9 @@ class ViTND(Module):
return_embed = False
):
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
batch, *spatial_dims, _, device = *x.shape, x.device
# Generate position coordinates
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
@@ -280,12 +280,12 @@ class ViTND(Module):
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
# flatten spatial dimensions for attention with nd rotary
pos = repeat(pos, '... p -> b (...) p', b = batch)
x, packed_shape = pack([x], 'b * d')
x = self.dropout(x)
embed = self.transformer(x, pos)
# return the embed with reconstituted patch shape
@@ -303,7 +303,7 @@ class ViTND(Module):
if __name__ == '__main__':
model = ViTND(
ndim = 5,
input_shape = (4, 8, 16, 32, 64),

View File

@@ -231,4 +231,4 @@ if __name__ == '__main__':
decorr_loss = DecorrelationLoss(0.0001, soft_validate_num_sampled = True)
out = decorr_loss(hiddens)
assert out.item() == 0
assert out.item() == 0

View File

@@ -0,0 +1,217 @@
from __future__ import annotations
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# functions
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim, bias = False),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.norm = nn.LayerNorm(dim, bias = False)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(
self,
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout = 0.,
keel_residual_scale = None
):
super().__init__()
assert depth > 1
self.layers = ModuleList([])
for _ in range(depth):
self.layers.extend([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
])
num_layers = depth * 2
self.keel_residual_scale = default(keel_residual_scale, num_layers)
self.post_norms = ModuleList([nn.LayerNorm(dim, bias = False) for _ in range(num_layers - 1)])
def forward(self, x):
residual_scale = self.keel_residual_scale
for layer_ind, layer in enumerate(self.layers):
first_layer = layer_ind == 0
residual = x
out = layer(x)
if first_layer:
x = out + residual
continue
post_norm = self.post_norms[layer_ind - 1]
x = post_norm(out + residual * residual_scale)
return x
class ViT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
pool = 'cls',
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
keel_residual_scale = None
):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_cls_tokens = 1 if pool == 'cls' else 0
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout,
keel_residual_scale = keel_residual_scale
)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
def forward(self, img):
batch = img.shape[0]
x = self.to_patch_embedding(img)
cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
x = torch.cat((cls_tokens, x), dim = 1)
seq = x.shape[1]
x = x + self.pos_embedding[:seq]
x = self.dropout(x)
x = self.transformer(x)
if not exists(self.mlp_head):
return x
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
if __name__ == '__main__':
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img)
assert preds.shape == (1, 1000)

View File

@@ -1,5 +1,10 @@
from collections import namedtuple
import torch
from torch import nn
from torch import nn, cat
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.nn.attention import SDPBackend, sdpa_kernel
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
@@ -9,12 +14,15 @@ from einops.layers.torch import Rearrange
def exists(val):
return val is not None
def divisible_by(num, den):
return (num % den) == 0
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class FeedForward(nn.Module):
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
@@ -28,9 +36,11 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_flash_attn = True):
super().__init__()
self.use_flash_attn = use_flash_attn
self.dropout_p = dropout
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
@@ -48,61 +58,100 @@ class Attention(nn.Module):
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
def flash_attn(self, q, k, v, mask = None):
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
dropout_p = self.dropout_p,
is_causal = False,
scale = self.scale
)
return out
def forward(self, x, mask = None):
batch, seq, _ = x.shape
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
attn = self.attend(dots)
attn = self.dropout(attn)
if self.use_flash_attn:
out = self.flash_attn(q, k, v, mask = mask)
else:
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(mask):
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
super().__init__()
self.use_flash_attn = use_flash_attn
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x) + x
x = attn(x, mask = mask) + x
x = ff(x) + x
return self.norm(x)
class FactorizedTransformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
class FactorizedTransformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
super().__init__()
self.use_flash_attn = use_flash_attn
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
b, f, n, _ = x.shape
def forward(self, x, mask = None):
batch, frames, seq, _ = x.shape
if exists(mask):
mask = repeat(mask, 'b ... -> (b space) ...', space = x.shape[2])
for spatial_attn, temporal_attn, ff in self.layers:
x = rearrange(x, 'b f n d -> (b f) n d')
x = spatial_attn(x) + x
x = rearrange(x, '(b f) n d -> (b n) f d', b=b, f=f)
x = temporal_attn(x) + x
x = rearrange(x, '(b f) n d -> (b n) f d', b = batch, f = frames)
x = temporal_attn(x, mask = mask) + x
x = ff(x) + x
x = rearrange(x, '(b n) f d -> b f n d', b=b, n=n)
x = rearrange(x, '(b n) f d -> b f n d', b = batch, n = seq)
return self.norm(x)
class ViT(nn.Module):
class ViViT(Module):
def __init__(
self,
*,
@@ -122,13 +171,14 @@ class ViT(nn.Module):
dropout = 0.,
emb_dropout = 0.,
variant = 'factorized_encoder',
use_flash_attn: bool = True,
):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(image_patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
assert divisible_by(image_height, patch_height) and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert divisible_by(frames, frame_patch_size), 'Frames must be divisible by frame patch size'
assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
@@ -138,6 +188,8 @@ class ViT(nn.Module):
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.frame_patch_size = frame_patch_size
self.global_average_pool = pool == 'mean'
self.to_patch_embedding = nn.Sequential(
@@ -154,11 +206,11 @@ class ViT(nn.Module):
if variant == 'factorized_encoder':
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
elif variant == 'factorized_self_attention':
assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
self.pool = pool
self.to_latent = nn.Identity()
@@ -166,25 +218,36 @@ class ViT(nn.Module):
self.mlp_head = nn.Linear(dim, num_classes)
self.variant = variant
def forward(self, video):
x = self.to_patch_embedding(video)
b, f, n, _ = x.shape
def forward(self, video, mask = None):
device = video.device
x = x + self.pos_embedding[:, :f, :n]
x = self.to_patch_embedding(video)
batch, frames, seq, _ = x.shape
x = x + self.pos_embedding[:, :frames, :seq]
if exists(self.spatial_cls_token):
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
x = torch.cat((spatial_cls_tokens, x), dim = 2)
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = batch, f = frames)
x = cat((spatial_cls_tokens, x), dim = 2)
x = self.dropout(x)
# maybe temporal mask
temporal_mask = None
if exists(mask):
temporal_mask = reduce(mask, 'b (f patch) -> b f', 'all', patch = self.frame_patch_size)
# the two variants
if self.variant == 'factorized_encoder':
x = rearrange(x, 'b f n d -> (b f) n d')
# attend across space
x = self.spatial_transformer(x)
x = rearrange(x, '(b f) n d -> b f n d', b = b)
x = rearrange(x, '(b f) n d -> b f n d', b = batch)
# excise out the spatial cls tokens or average pool for temporal attention
@@ -193,22 +256,50 @@ class ViT(nn.Module):
# append temporal CLS tokens
if exists(self.temporal_cls_token):
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = batch)
x = torch.cat((temporal_cls_tokens, x), dim = 1)
x = cat((temporal_cls_tokens, x), dim = 1)
if exists(temporal_mask):
temporal_mask = F.pad(temporal_mask, (1, 0), value = True)
# attend across time
x = self.temporal_transformer(x)
x = self.temporal_transformer(x, mask = temporal_mask)
# excise out temporal cls token or average pool
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
elif self.variant == 'factorized_self_attention':
x = self.factorized_transformer(x)
x = self.factorized_transformer(x, mask = temporal_mask)
x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')
x = self.to_latent(x)
return self.mlp_head(x)
# main
if __name__ == '__main__':
vivit = ViViT(
dim = 512,
spatial_depth = 2,
temporal_depth = 2,
heads = 4,
mlp_dim = 2048,
image_size = 256,
image_patch_size = 16,
frames = 8,
frame_patch_size = 2,
num_classes = 1000,
variant = 'factorized_encoder',
)
video = torch.randn(3, 3, 8, 256, 256)
mask = torch.randint(0, 2, (3, 8)).bool()
logits = vivit(video, mask = None)
assert logits.shape == (3, 1000)