Compare commits

25 Commits
1.16.4 ... main

Author SHA1 Message Date
lucidrains
2e86e34b01 add moss to accept video wrapper, aimed towards SRT-H 2026-05-19 08:43:57 -07:00
lucidrains
a9eaa74e14 a copy of MOSS (stack of STSS) for personal use, stop gap measure before video models get commoditized 2026-05-18 08:23:05 -07:00
lucidrains
93df0e6046 add another vit variant, where they found improvements for certain tasks when cls token get its own specialized parameters - layernorms were crucial 2026-05-01 13:49:43 -07:00
lucidrains
8e104e9afc cleanup vit det pool 2026-05-01 12:40:31 -07:00
lucidrains
3f03aa3994 add a vit that can accept an object mask (from sam or other seg models), and only attends and pools those patch tokens 2026-05-01 12:17:19 -07:00
lucidrains
2da1b45b9b allow vit to modulate the parallel and orthog components 2026-04-18 09:46:29 -07:00
lucidrains
7ab07c2499 add vit with orthogonal residual update 2026-04-18 09:24:52 -07:00
lucidrains
dea6b0da56 blur the line between depth and recurrence even more 2026-04-06 09:27:02 -07:00
lucidrains
13284b7af1 first attention residual should be disabled, cleanup 2026-04-05 20:14:06 -07:00
lucidrains
7e18d0302e stop relying on github 2026-03-27 08:07:42 -07:00
lucidrains
b80676e09c add an attention residual example (kimi team) as well as dino / byol redone with sigreg (lejepa) 2026-03-27 07:51:12 -07:00
lucidrains
fc1e727428 add ability to condition on binned advantages for the vision action transformers 2026-02-19 14:10:44 -08:00
lucidrains
6032a54b48 patch 2026-02-11 11:49:51 -08:00
Harikrishna KP
06a1f42924 Fix ViViT Transformer not passing use_flash_attn to Attention and duplicate mask reshape (#360)
Two related bugs in vivit.py:

1. Transformer.__init__ accepted use_flash_attn but never forwarded it to the
   Attention modules it creates. Since Attention defaults to use_flash_attn=True,
   setting use_flash_attn=False on ViViT had no effect on the factorized_encoder
   variant's spatial and temporal transformers.

2. Attention.forward reshaped the mask from 2D to 4D before the flash/non-flash
   branch (line 82), then attempted to reshape it again inside the non-flash
   branch (line 92). When the non-flash code path is actually reached with a
   mask, einops raises an error because the mask is already 4D.

   These bugs masked each other: bug #1 prevented bug #2 from triggering because
   the non-flash path was never taken even when requested.

Fix: pass use_flash_attn through to Attention in Transformer.__init__, and
remove the redundant second mask rearrange in the non-flash branch.
2026-02-11 11:49:31 -08:00
Phil Wang
6ae6a3ab64 cleanup 2026-02-04 13:29:40 -08:00
lucidrains
827300beed add vit with keel post ln, proposed by bytedance for scaling depth 2026-02-04 09:09:17 -08:00
lucidrains
a7c4e7f79f best practices 2026-01-28 05:04:20 -08:00
lucidrains
54ec3f2af5 address https://github.com/lucidrains/vit-pytorch/issues/357 2026-01-17 05:11:36 -08:00
lucidrains
9aa52cce49 do an actual vat with siglip arch, and have gemini flash craft the weight loading script from hf 2026-01-15 11:27:05 -08:00
lucidrains
4c89017444 fix up vivit 2026-01-08 06:36:40 -08:00
Eyal Mazuz
580258d99e Allow to pass mask parameter for temporal transformer in ViVit (#356)
* Mask for temporal transformer in ViVit

This allows to pad videos to certain length which allow the transformer
to ignore padded frames using batch sizes > 1

* Added flash attention to vivit

* Added flash attention to vivit

* Added flash attention to vivit
2026-01-08 06:08:54 -08:00
lucidrains
6f1caef987 allow for no final output head on the vit 2026-01-06 13:00:48 -08:00
lucidrains
fb5014f0ee get a version of n-dimensional vit with golden gate polar coordinate embeddings into the repo for future use 2025-12-25 09:11:13 -08:00
Phil Wang
0b7518ef45 educate 2025-12-21 07:06:20 -08:00
lucidrains
077d8c188f fix distill 2025-12-10 15:52:10 -08:00
55 changed files with 3635 additions and 644 deletions
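
The way the two ViViT bugs fixed in 06a1f42924 hid each other can be sketched with a dependency-free toy. Everything here is a hypothetical stand-in invented for illustration, not the code in vivit.py: `ToyAttention`, `ToyTransformer`, `to_4d`, and the `forward_flag` / `double_reshape` switches are assumed names, nested lists stand in for tensors, and a `ValueError` stands in for the einops error.

```python
# Toy stand-ins for illustration only; these are NOT the classes from vivit.py.

def to_4d(mask):
    """Expand a 2D (batch, seq) mask of bools to (batch, 1, 1, seq); reject anything else."""
    is_2d = all(
        isinstance(row, list) and all(isinstance(v, bool) for v in row)
        for row in mask
    )
    if not is_2d:
        raise ValueError('expected a 2D mask')
    return [[[row]] for row in mask]

class ToyAttention:
    def __init__(self, use_flash_attn=True, double_reshape=False):
        self.use_flash_attn = use_flash_attn
        self.double_reshape = double_reshape

    def forward(self, mask):
        mask = to_4d(mask)          # reshape happens before the flash / non-flash branch
        if self.use_flash_attn:
            return 'flash'
        if self.double_reshape:
            mask = to_4d(mask)      # bug 2: the mask is already 4D at this point
        return 'non-flash'

class ToyTransformer:
    def __init__(self, use_flash_attn=True, forward_flag=True, double_reshape=False):
        kwargs = dict(double_reshape=double_reshape)
        if forward_flag:
            kwargs['use_flash_attn'] = use_flash_attn   # the fix for bug 1
        self.attn = ToyAttention(**kwargs)

    def forward(self, mask):
        return self.attn.forward(mask)

mask = [[True, True, False]]

# Both bugs present: the flag is dropped, the flash path runs, bug 2 stays hidden.
both = ToyTransformer(use_flash_attn=False, forward_flag=False, double_reshape=True)
print(both.forward(mask))    # flash

# Only bug 1 fixed: the non-flash path is reached and the second reshape fails.
try:
    ToyTransformer(use_flash_attn=False, forward_flag=True, double_reshape=True).forward(mask)
except ValueError as err:
    print('error:', err)     # error: expected a 2D mask

# Both fixed: the requested non-flash path runs cleanly.
fixed = ToyTransformer(use_flash_attn=False, forward_flag=True, double_reshape=False)
print(fixed.forward(mask))   # non-flash
```

This is why the commit applies both changes together: forwarding the flag alone would expose the second reshape.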

3
.github/FUNDING.yml vendored

@@ -1,3 +0,0 @@
# These are supported funding model platforms
github: [lucidrains]

View File

@@ -1,36 +0,0 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
name: Upload Python Package
on:
  release:
    types: [published]
jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.x'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build
      - name: Build package
        run: python -m build
      - name: Publish package
        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
        with:
          user: __token__
          password: ${{ secrets.PYPI_API_TOKEN }}

View File

@@ -1,34 +0,0 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Test
on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8, 3.9]
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
          python -m pip install -e .
          python -m pip install pytest
      - name: Test with pytest
        run: |
          pytest -q

3
.gitignore vendored

@@ -127,3 +127,6 @@ dmypy.json
# Pyre type checker
.pyre/
# scripts
*.sh

177
README.md

@@ -49,7 +49,7 @@
## Vision Transformer - Pytorch
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the [attention](https://www.youtube.com/watch?v=eMlx5fFNoYc) revolution.
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>.
@@ -90,26 +90,26 @@ preds = v(img) # (1, 1000)
## Parameters
- `image_size`: int.
- `image_size`: int.
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
- `patch_size`: int.
Size of patches. `image_size` must be divisible by `patch_size`.
- `patch_size`: int.
Size of patches. `image_size` must be divisible by `patch_size`.
The number of patches is: ` n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
- `num_classes`: int.
- `num_classes`: int.
Number of classes to classify.
- `dim`: int.
- `dim`: int.
Last dimension of output tensor after linear transformation `nn.Linear(..., dim)`.
- `depth`: int.
- `depth`: int.
Number of Transformer blocks.
- `heads`: int.
Number of heads in Multi-head Attention layer.
- `mlp_dim`: int.
Dimension of the MLP (FeedForward) layer.
- `channels`: int, default `3`.
Number of image's channels.
- `dropout`: float between `[0, 1]`, default `0.`.
Dropout rate.
- `emb_dropout`: float between `[0, 1]`, default `0`.
- `heads`: int.
Number of heads in Multi-head Attention layer.
- `mlp_dim`: int.
Dimension of the MLP (FeedForward) layer.
- `channels`: int, default `3`.
Number of image's channels.
- `dropout`: float between `[0, 1]`, default `0.`.
Dropout rate.
- `emb_dropout`: float between `[0, 1]`, default `0`.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling
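
As a quick check of the patch-count formula in the parameter list above, here is a minimal sketch. The helper name `num_patches` and the 256 / 32 values are illustrative assumptions, not part of the library.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Patch count for a square image: n = (image_size // patch_size) ** 2."""
    if image_size % patch_size != 0:
        raise ValueError('image_size must be divisible by patch_size')
    return (image_size // patch_size) ** 2

n = num_patches(256, 32)
print(n)        # 64
print(n > 16)   # True, so this configuration satisfies the documented lower bound
```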
@@ -972,7 +972,7 @@ torch.save(model.state_dict(), './pretrained-net.pt')
<img src="./images/mp3.png" width="400px"></img>
New <a href="https://arxiv.org/abs/2207.07611">paper</a> that introduces masked position prediction pre-training criteria. This strategy is more efficient than the Masked Autoencoder strategy and has comparable performance.
New <a href="https://arxiv.org/abs/2207.07611">paper</a> that introduces masked position prediction pre-training criteria. This strategy is more efficient than the Masked Autoencoder strategy and has comparable performance.
```python
import torch
@@ -1358,10 +1358,10 @@ learner = Dino(
hidden_layer = 'to_latent', # hidden layer name or index, from which to extract the embedding
projection_hidden_size = 256, # projector network hidden dimension
projection_layers = 4, # number of layers in projection network
num_classes_K = 65336, # output logits dimensions (referenced as K in paper)
num_classes_K = 65536, # output logits dimensions (referenced as K in paper)
student_temp = 0.9, # student temperature
teacher_temp = 0.04, # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper
local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper
global_lower_crop_scale = 0.5, # lower bound for global crop - 0.5 was recommended in the paper
moving_average_decay = 0.9, # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
@@ -1735,7 +1735,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{touvron2020training,
title = {Training data-efficient image transformers & distillation through attention},
title = {Training data-efficient image transformers & distillation through attention},
author = {Hugo Touvron and Matthieu Cord and Matthijs Douze and Francisco Massa and Alexandre Sablayrolles and Hervé Jégou},
year = {2020},
eprint = {2012.12877},
@@ -1768,7 +1768,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{touvron2021going,
title = {Going deeper with Image Transformers},
title = {Going deeper with Image Transformers},
author = {Hugo Touvron and Matthieu Cord and Alexandre Sablayrolles and Gabriel Synnaeve and Hervé Jégou},
year = {2021},
eprint = {2103.17239},
@@ -1801,7 +1801,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{heo2021rethinking,
title = {Rethinking Spatial Dimensions of Vision Transformers},
title = {Rethinking Spatial Dimensions of Vision Transformers},
author = {Byeongho Heo and Sangdoo Yun and Dongyoon Han and Sanghyuk Chun and Junsuk Choe and Seong Joon Oh},
year = {2021},
eprint = {2103.16302},
@@ -1845,7 +1845,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
@@ -1867,7 +1867,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{chen2021regionvit,
title = {RegionViT: Regional-to-Local Attention for Vision Transformers},
title = {RegionViT: Regional-to-Local Attention for Vision Transformers},
author = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan},
year = {2021},
eprint = {2106.02689},
@@ -1878,7 +1878,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{wang2021crossformer,
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
year = {2021},
eprint = {2108.00154},
@@ -1900,7 +1900,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{he2021masked,
title = {Masked Autoencoders Are Scalable Vision Learners},
title = {Masked Autoencoders Are Scalable Vision Learners},
author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
year = {2021},
eprint = {2111.06377},
@@ -1911,7 +1911,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{xie2021simmim,
title = {SimMIM: A Simple Framework for Masked Image Modeling},
title = {SimMIM: A Simple Framework for Masked Image Modeling},
author = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
year = {2021},
eprint = {2111.09886},
@@ -1944,7 +1944,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{lee2021vision,
title = {Vision Transformer for Small-Size Datasets},
title = {Vision Transformer for Small-Size Datasets},
author = {Seung Hoon Lee and Seunghyun Lee and Byung Cheol Song},
year = {2021},
eprint = {2112.13492},
@@ -1966,7 +1966,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{yang2022scalablevit,
title = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer},
title = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer},
author = {Rui Yang and Hailong Ma and Jie Wu and Yansong Tang and Xuefeng Xiao and Min Zheng and Xiu Li},
year = {2022},
eprint = {2203.10790},
@@ -2203,13 +2203,128 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{carrigg2025decorrelationspeedsvisiontransformers,
title = {Decorrelation Speeds Up Vision Transformers},
title = {Decorrelation Speeds Up Vision Transformers},
author = {Kieran Carrigg and Rob van Gastel and Melda Yeghaian and Sander Dalm and Faysal Boughorbel and Marcel van Gerven},
year = {2025},
eprint = {2510.14657},
archivePrefix = {arXiv},
primaryClass = {cs.CV},
url = {https://arxiv.org/abs/2510.14657},
url = {https://arxiv.org/abs/2510.14657},
}
```
```bibtex
@misc{gopalakrishnan2025decouplingwhatwherepolar,
title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
year = {2025},
eprint = {2509.10534},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2509.10534},
}
```
```bibtex
@misc{qiu2025gatedattentionlargelanguage,
title = {Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free},
author = {Zihan Qiu and Zekun Wang and Bo Zheng and Zeyu Huang and Kaiyue Wen and Songlin Yang and Rui Men and Le Yu and Fei Huang and Suozhi Huang and Dayiheng Liu and Jingren Zhou and Junyang Lin},
year = {2025},
eprint = {2505.06708},
archivePrefix = {arXiv},
primaryClass = {cs.CL},
url = {https://arxiv.org/abs/2505.06708}
}
```
```bibtex
@misc{chen2026postlayernormbackstableexpressive,
title = {Post-LayerNorm Is Back: Stable, ExpressivE, and Deep},
author = {Chen Chen and Lai Wei},
year = {2026},
eprint = {2601.19895},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2601.19895},
}
```
```bibtex
@misc{intelligence2025pi06vlalearnsexperience,
title = {$\pi^{*}_{0.6}$: a VLA That Learns From Experience},
author = {Physical Intelligence and Ali Amin and Raichelle Aniceto and Ashwin Balakrishna and Kevin Black and Ken Conley and Grace Connors and James Darpinian and Karan Dhabalia and Jared DiCarlo and Danny Driess and Michael Equi and Adnan Esmail and Yunhao Fang and Chelsea Finn and Catherine Glossop and Thomas Godden and Ivan Goryachev and Lachy Groom and Hunter Hancock and Karol Hausman and Gashon Hussein and Brian Ichter and Szymon Jakubczak and Rowan Jen and Tim Jones and Ben Katz and Liyiming Ke and Chandra Kuchi and Marinda Lamb and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Yao Lu and Vishnu Mano and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Charvi Sharma and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and Will Stoeckle and Alex Swerdlow and James Tanner and Marcel Torne and Quan Vuong and Anna Walling and Haohuan Wang and Blake Williams and Sukwon Yoo and Lili Yu and Ury Zhilinsky and Zhiyuan Zhou},
year = {2025},
eprint = {2511.14759},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2511.14759},
}
```
```bibtex
@misc{kimiteam2026attentionresiduals,
title = {Attention Residuals},
author = {Kimi Team and Guangyu Chen and Yu Zhang and Jianlin Su and Weixin Xu and Siyuan Pan and Yaoyu Wang and Yucheng Wang and Guanduo Chen and Bohong Yin and Yutian Chen and Junjie Yan and Ming Wei and Y. Zhang and Fanqing Meng and Chao Hong and Xiaotong Xie and Shaowei Liu and Enzhe Lu and Yunpeng Tai and Yanru Chen and Xin Men and Haiqing Guo and Y. Charles and Haoyu Lu and Lin Sui and Jinguo Zhu and Zaida Zhou and Weiran He and Weixiao Huang and Xinran Xu and Yuzhi Wang and Guokun Lai and Yulun Du and Yuxin Wu and Zhilin Yang and Xinyu Zhou},
year = {2026},
eprint = {2603.15031},
archivePrefix = {arXiv},
primaryClass = {cs.CL},
url = {https://arxiv.org/abs/2603.15031},
}
```
```bibtex
@misc{balestriero2025lejepa,
title = {LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics},
author = {Randall Balestriero and Yann LeCun},
year = {2025},
eprint = {2511.08544},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2511.08544},
}
```
```bibtex
@misc{oh2026revisitingresidualconnectionsorthogonal,
title = {Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks},
author = {Giyeong Oh and Woohyun Cho and Siyeol Kim and Suhwan Choi and Youngjae Yu},
year = {2026},
eprint = {2505.11881},
archivePrefix = {arXiv},
primaryClass = {cs.CV},
url = {https://arxiv.org/abs/2505.11881},
}
```
```bibtex
@inproceedings{niu2026learning,
title = {Learning to Grasp Anything By Playing with Random Toys},
author = {Dantong Niu and Yuvan Sharma and Baifeng Shi and Rachel Ding and Matteo Gioia and Haoru Xue and Henry Tsai and Konstantinos Kallidromitis and Anirudh Pai and S. Shankar Sastry and Trevor Darrell and Jitendra Malik and Roei Herzig},
booktitle = {The Fourteenth International Conference on Learning Representations},
year = {2026},
url = {https://openreview.net/forum?id=NZDaMcpXZm}
}
```
```bibtex
@misc{marouani2026revisitingclspatchtoken,
title = {Revisiting [CLS] and Patch Token Interaction in Vision Transformers},
author = {Alexis Marouani and Oriane Siméoni and Hervé Jégou and Piotr Bojanowski and Huy V. Vo},
year = {2026},
eprint = {2602.08626},
archivePrefix = {arXiv},
primaryClass = {cs.CV},
url = {https://arxiv.org/abs/2602.08626},
}
```
```bibtex
@misc{kim2026exploring,
title = {Exploring High-Order Self-Similarity for Video Understanding},
author = {Manjin Kim and Heeseung Kwon and Karteek Alahari and Minsu Cho},
year = {2026},
url = {https://openreview.net/forum?id=Co6SCyBIjo}
}
```

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.16.4"
version = "1.21.0"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }
@@ -31,8 +31,8 @@ classifiers = [
"Programming Language :: Python :: 3.12",
]
dependencies = [
"einops>=0.7.0",
"torch>=1.10",
"einops>=0.8.2",
"torch>=2.4",
"torchvision",
]
@@ -44,8 +44,8 @@ test = [
]
[project.urls]
Homepage = "https://github.com/lucidrains/vit-pytorch"
Repository = "https://github.com/lucidrains/vit-pytorch"
Homepage = "https://codeberg.org/lucidrains/vit-pytorch"
Repository = "https://codeberg.org/lucidrains/vit-pytorch"
[tool.setuptools]
include-package-data = true

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from contextlib import nullcontext
import torch
@@ -5,6 +7,8 @@ from torch import is_tensor, randn
from torch.nn import Module, Linear, Parameter
from torch.utils._pytree import tree_flatten, tree_unflatten
from vit_pytorch.vivit_with_moss import MOSS
from einops import rearrange, repeat
# helper functions
@@ -15,6 +19,9 @@ def exists(v):
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class AcceptVideoWrapper(Module):
@@ -26,8 +33,10 @@ class AcceptVideoWrapper(Module):
dim_emb = None,
time_seq_len = None,
embed_is_channel_first = False,
output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
proj_embed_to_dim = None
output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
proj_embed_to_dim = None,
patch_size = None,
moss: Module | dict | None = None
):
super().__init__()
self.image_net = image_net
@@ -56,6 +65,24 @@ class AcceptVideoWrapper(Module):
self.embed_is_channel_first = embed_is_channel_first
# patch size and moss
if not exists(patch_size):
if hasattr(image_net, 'patch_size'):
patch_size = image_net.patch_size
elif hasattr(image_net, 'vit') and hasattr(image_net.vit, 'patch_size'):
patch_size = image_net.vit.patch_size
self.patch_size = patch_size
if isinstance(moss, dict):
moss = MOSS(**moss)
self.moss = moss
if exists(self.moss):
assert exists(self.patch_size), '`patch_size` must be provided either on the `image_net` or passed in explicitly if using MOSS'
def forward(
self,
video, # (b c t h w)
@@ -70,6 +97,8 @@ class AcceptVideoWrapper(Module):
if add_time_pos_emb:
assert time <= self.time_seq_len, f'received video with {time} frames but `time_seq_len` ({self.time_seq_len}) is too low'
video_height, video_width = video.shape[-2:]
video = rearrange(video, 'b c t h w -> b t c h w')
video = rearrange(video, 'b t ... -> (b t) ...')
@@ -127,6 +156,27 @@ class AcceptVideoWrapper(Module):
outputs[self.output_pos_add_pos_emb] = embed
# moss - stack of stss
# https://openreview.net/forum?id=Co6SCyBIjo
if exists(self.moss):
outputs = list(outputs)
embed = outputs[self.output_pos_add_pos_emb]
patch_h, patch_w = pair(self.patch_size)
num_h, num_w = video_height // patch_h, video_width // patch_w
num_patches = num_h * num_w
num_cls_tokens = embed.shape[-2] - num_patches
cls_tokens, patch_tokens = embed[:, :, :num_cls_tokens], embed[:, :, num_cls_tokens:]
patch_tokens = rearrange(patch_tokens, 'b t (h w) d -> b t h w d', h = num_h, w = num_w)
patch_tokens = self.moss(patch_tokens)
patch_tokens = rearrange(patch_tokens, 'b t h w d -> b t (h w) d')
embed = torch.cat((cls_tokens, patch_tokens), dim = -2)
outputs[self.output_pos_add_pos_emb] = embed
return tree_unflatten(outputs, tree_spec)
# main
@@ -151,9 +201,28 @@ if __name__ == '__main__':
# step up the difficulty and return embeddings for robotics
from vit_pytorch.extractor import Extractor
v = Extractor(v)
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024, proj_embed_to_dim = 512)
moss_kwargs = dict(
dim = 512,
local_time = 3,
local_height = 3,
local_width = 3,
hidden_dim = 64,
orders = 2,
causal = True
)
video_acceptor = AcceptVideoWrapper(
v,
add_time_pos_emb = True,
output_pos_add_pos_emb = 1,
time_seq_len = 12,
dim_emb = 1024,
proj_embed_to_dim = 512,
moss = moss_kwargs
)
logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
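
The MOSS branch added to `forward` recovers the patch grid from the frame size and treats any surplus leading tokens as CLS tokens. Below is a minimal sketch of just that arithmetic in plain Python. `split_token_counts` and the example sizes are hypothetical, and the explicit error for a too-short sequence is an added assumption that does not appear in the diff.

```python
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def split_token_counts(seq_len, frame_hw, patch_size):
    """Return (num_cls_tokens, (num_h, num_w)) for one frame's token sequence."""
    patch_h, patch_w = pair(patch_size)
    height, width = frame_hw

    # patch grid, mirroring `video_height // patch_h, video_width // patch_w`
    num_h, num_w = height // patch_h, width // patch_w
    num_patches = num_h * num_w

    # whatever precedes the patch tokens is assumed to be CLS-like tokens
    num_cls_tokens = seq_len - num_patches
    if num_cls_tokens < 0:
        raise ValueError('sequence is shorter than the patch grid implies')

    return num_cls_tokens, (num_h, num_w)

print(split_token_counts(65, (256, 256), 32))         # (1, (8, 8))
print(split_token_counts(64, (256, 256), (32, 32)))   # (0, (8, 8))
```

With the grid known, the patch tokens can be reshaped to `(b, t, h, w, d)` for the MOSS module and flattened back afterwards, as the diff does with `rearrange`.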

View File

@@ -216,7 +216,7 @@ class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, max_tokens_per_depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

View File

@@ -25,12 +25,12 @@ class DistillMixin:
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
cls_tokens = repeat(self.cls_token, 'n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
x += self.pos_embedding[:(n + 1)]
if distilling:
distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
distill_tokens = repeat(distill_token, 'n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
x = self._attend(x)
@@ -125,7 +125,7 @@ class DistillWrapper(Module):
self.alpha = alpha
self.hard = hard
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
self.distillation_token = nn.Parameter(torch.randn(1, dim))
self.distill_mlp = nn.Sequential(
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),

View File

@@ -86,7 +86,7 @@ class JumboViT(Module):
):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
@@ -103,7 +103,7 @@ class JumboViT(Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
)
jumbo_cls_dim = dim * jumbo_cls_k

View File

@@ -108,7 +108,7 @@ class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
@@ -145,7 +145,7 @@ class ViT(nn.Module):
return x
def forward(self, img):
x = self.img_to_tokens(img)
x = self.img_to_tokens(img)
x = self.transformer(x)
@@ -160,7 +160,7 @@ class Adapter(nn.Module):
*,
vit,
num_memories_per_layer = 10,
num_classes = 2,
num_classes = 2,
):
super().__init__()
assert isinstance(vit, ViT)
@@ -188,7 +188,7 @@ class Adapter(nn.Module):
)
# specialized attention mask to preserve the output of the original ViT
# it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa
# it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa
attn_mask = torch.ones((num_patches, num_patches), dtype = torch.bool)
attn_mask = F.pad(attn_mask, (1, num_memories_per_layer), value = False) # main tokens cannot attend to learnable memories per layer
@@ -203,7 +203,7 @@ class Adapter(nn.Module):
# add task specific memory tokens
memory_cls_tokens = repeat(self.memory_cls_token, 'd -> b 1 d', b = b)
tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)
tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)
# pass memories along with image tokens through transformer for attending

319
vit_pytorch/lejepa.py Normal file

@@ -0,0 +1,319 @@
import random
from functools import wraps
import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torchvision import transforms as T
from einops import rearrange
# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance

            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn

def get_module_device(module):
    return next(module.parameters()).device

def l2norm(t, eps = 1e-6):
    return F.normalize(t, dim = -1, eps = eps)

# loss function

def sigreg_loss(
    x,
    num_slices = 1024,
    domain = (-5, 5),
    num_knots = 17
):
    # Randall Balestriero - https://arxiv.org/abs/2511.08544

    dim, device = x.shape[-1], x.device

    # slice sampling

    rand_projs = torch.randn((num_slices, dim), device = device)
    rand_projs = l2norm(rand_projs)

    # integration points

    t = torch.linspace(*domain, num_knots, device = device)

    # theoretical CF for N(0, 1) and Gauss. window

    exp_f = (-0.5 * t.square()).exp()

    # empirical CF

    x_t = torch.einsum('... d, m d -> ... m', x, rand_projs)
    x_t = rearrange(x_t, '... m -> (...) m')
    x_t = rearrange(x_t, 'n m -> n m 1') * t
    ecf = (1j * x_t).exp().mean(dim = 0)

    # weighted L2 distance

    err = ecf.sub(exp_f).abs().square().mul(exp_f)

    return torch.trapezoid(err, t, dim = -1).mean()

# augmentation utils

class RandomApply(Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)

# MLP class for projector

class L2Norm(Module):
    def forward(self, x, eps = 1e-6):
        return l2norm(x, eps)

class MLP(Module):
    def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
        super().__init__()

        layers = []
        dims = (dim, *((hidden_size,) * (num_layers - 1)))

        for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
            # zip yields len(dims) - 1 pairs, so the final pair has index len(dims) - 2
            is_last = ind == (len(dims) - 2)

            layers.extend([
                nn.Linear(layer_dim_in, layer_dim_out),
                nn.GELU() if not is_last else nn.Identity()
            ])

        self.net = nn.Sequential(
            *layers,
            L2Norm(),
            nn.Linear(hidden_size, dim_out)
        )

    def forward(self, x):
        return self.net(x)

# wrapper

class NetWrapper(Module):
    def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        self.projector = None
        self.projection_hidden_size = projection_hidden_size
        self.projection_num_layers = projection_num_layers
        self.output_dim = output_dim

        self.hidden = {}
        self.hook_registered = False

    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None

    def _hook(self, _, input, output):
        device = input[0].device
        self.hidden[device] = output.flatten(1)

    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    @singleton('projector')
    def _get_projector(self, hidden):
        _, dim = hidden.shape
        projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
        return projector.to(hidden)

    def get_embedding(self, x):
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        self.hidden.clear()
        _ = self.net(x)
        hidden = self.hidden[x.device]
        self.hidden.clear()

        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    def forward(self, x, return_projection = True):
        embed = self.get_embedding(x)
        if not return_projection:
            return embed

        projector = self._get_projector(embed)
        return projector(embed), embed

# main class
class LeJEPA(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer = -2,
projection_hidden_size = 256,
num_classes_K = 65336,
projection_layers = 4,
local_upper_crop_scale = 0.4,
global_lower_crop_scale = 0.5,
target_loss_weight = 1.,
sigreg_loss_weight = 1.,
sigreg_loss_kwargs = dict(
num_slices = 1024,
domain = (-5, 5),
num_knots = 17
),
augment_fn = None,
augment_fn2 = None
):
super().__init__()
self.net = net
# default BYOL augmentation
DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, DEFAULT_AUG)
# local and global crops
self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))
self.encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)
self.target_loss_weight = target_loss_weight
self.sigreg_loss_weight = sigreg_loss_weight
self.sigreg_loss_kwargs = sigreg_loss_kwargs
# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)
# send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
def forward(
self,
x,
return_embedding = False,
return_projection = True
):
if return_embedding:
return self.encoder(x, return_projection = return_projection)
image_one, image_two = self.augment1(x), self.augment2(x)
local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
local_images = torch.cat((local_image_one, local_image_two), dim = 0)
proj_locals, _ = self.encoder(local_images)
proj_local_one, proj_local_two = proj_locals.chunk(2, dim = 0)
with torch.no_grad():
global_images = torch.cat((global_image_one, global_image_two), dim = 0)
proj_globals, _ = self.encoder(global_images)
proj_global_one, proj_global_two = proj_globals.chunk(2, dim = 0)
# invariance loss
mse_loss = F.mse_loss(proj_local_one, proj_global_two) + F.mse_loss(proj_local_two, proj_global_one)
# sigreg loss
sreg_loss = sigreg_loss(proj_locals, **self.sigreg_loss_kwargs)
return mse_loss * self.target_loss_weight + sreg_loss * self.sigreg_loss_weight
# quick run
if __name__ == '__main__':
from vit_pytorch import ViT
model = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
)
learner = LeJEPA(
model,
image_size = 256,
hidden_layer = 'to_latent', # layer name where output is hidden dimension
projection_hidden_size = 256, # projector network hidden dimension
projection_layers = 4, # number of layers in projection network
num_classes_K = 65336, # output dimension
target_loss_weight = 1.0,
sigreg_loss_weight = 1.0
)
opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)
images = torch.randn(8, 3, 256, 256)
loss = learner(images)
opt.zero_grad()
loss.backward()
opt.step()
print('loss:', loss.item())
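The invariance term in `LeJEPA.forward` pairs each local view with the global view of the other augmentation. A minimal sketch follows, with plain Python lists standing in for tensors. The `mse` helper and the toy vectors are assumptions for illustration; the listing itself uses `F.mse_loss` on batched projections, with the global branch computed under `torch.no_grad()`.

```python
def mse(a, b):
    """Mean squared error between two equal-length vectors."""
    assert len(a) == len(b)
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

# toy projections for one image: two local crops and two global crops
proj_local_one, proj_local_two = [0.0, 1.0], [1.0, 1.0]
proj_global_one, proj_global_two = [1.0, 1.0], [0.0, 0.0]

# the local view of augmentation one is matched to the global view of
# augmentation two, and vice versa
invariance = mse(proj_local_one, proj_global_two) + mse(proj_local_two, proj_global_one)
print(invariance)
```

The SIGReg term is left out of this sketch because `sigreg_loss` is not shown in this part of the diff.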

View File

@@ -182,7 +182,7 @@ class LeViT(nn.Module):
def forward(self, img):
x = self.conv_embedding(img)
x = self.backbone(x)
x = self.pool(x)

View File

@@ -52,7 +52,7 @@ class MAE(nn.Module):
if self.encoder.pool == "cls":
tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
elif self.encoder.pool == "mean":
tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)
# calculate the number of patches to be masked, and get random indices, dividing them up into masked vs unmasked
@@ -87,7 +87,7 @@ class MAE(nn.Module):
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
# concat the masked tokens to the decoder tokens and attend with decoder
decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
decoder_tokens[batch_range, masked_indices] = mask_tokens

View File

@@ -77,7 +77,7 @@ class Dropsample(nn.Module):
def __init__(self, prob = 0):
super().__init__()
self.prob = prob
def forward(self, x):
device = x.device

View File

@@ -72,7 +72,7 @@ class Dropsample(Module):
def __init__(self, prob = 0):
super().__init__()
self.prob = prob
def forward(self, x):
device = x.device

View File

@@ -1,243 +1,243 @@
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Reduce
# helpers
def conv_1x1_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.SiLU()
)
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
return nn.Sequential(
nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.SiLU()
)
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b p h n d -> b p n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
"""Transformer block described in ViT.
Paper: https://arxiv.org/abs/2010.11929
Based on: https://github.com/lucidrains/vit-pytorch
"""
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads, dim_head, dropout),
FeedForward(dim, mlp_dim, dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class MV2Block(nn.Module):
"""MV2 block described in MobileNetV2.
Paper: https://arxiv.org/pdf/1801.04381
Based on: https://github.com/tonylins/pytorch-mobilenet-v2
"""
def __init__(self, inp, oup, stride=1, expansion=4):
super().__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(inp * expansion)
self.use_res_connect = self.stride == 1 and inp == oup
if expansion == 1:
self.conv = nn.Sequential(
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
else:
self.conv = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
def forward(self, x):
out = self.conv(x)
if self.use_res_connect:
out = out + x
return out
class MobileViTBlock(nn.Module):
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
super().__init__()
self.ph, self.pw = patch_size
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
self.conv2 = conv_1x1_bn(channel, dim)
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
self.conv3 = conv_1x1_bn(dim, channel)
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
def forward(self, x):
y = x.clone()
# Local representations
x = self.conv1(x)
x = self.conv2(x)
# Global representations
_, _, h, w = x.shape
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
# Fusion
x = self.conv3(x)
x = torch.cat((x, y), 1)
x = self.conv4(x)
return x
class MobileViT(nn.Module):
"""MobileViT.
Paper: https://arxiv.org/abs/2110.02178
Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
"""
def __init__(
self,
image_size,
dims,
channels,
num_classes,
expansion=4,
kernel_size=3,
patch_size=(2, 2),
depths=(2, 4, 3)
):
super().__init__()
assert len(dims) == 3, 'dims must be a tuple of 3'
assert len(depths) == 3, 'depths must be a tuple of 3'
ih, iw = image_size
ph, pw = patch_size
assert ih % ph == 0 and iw % pw == 0
init_dim, *_, last_dim = channels
self.conv1 = conv_nxn_bn(3, init_dim, stride=2)
self.stem = nn.ModuleList([])
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.trunk = nn.ModuleList([])
self.trunk.append(nn.ModuleList([
MV2Block(channels[3], channels[4], 2, expansion),
MobileViTBlock(dims[0], depths[0], channels[5],
kernel_size, patch_size, int(dims[0] * 2))
]))
self.trunk.append(nn.ModuleList([
MV2Block(channels[5], channels[6], 2, expansion),
MobileViTBlock(dims[1], depths[1], channels[7],
kernel_size, patch_size, int(dims[1] * 4))
]))
self.trunk.append(nn.ModuleList([
MV2Block(channels[7], channels[8], 2, expansion),
MobileViTBlock(dims[2], depths[2], channels[9],
kernel_size, patch_size, int(dims[2] * 4))
]))
self.to_logits = nn.Sequential(
conv_1x1_bn(channels[-2], last_dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(channels[-1], num_classes, bias=False)
)
def forward(self, x):
x = self.conv1(x)
for conv in self.stem:
x = conv(x)
for conv, attn in self.trunk:
x = conv(x)
x = attn(x)
return self.to_logits(x)

View File

@@ -107,10 +107,10 @@ class ViT(nn.Module):
def __init__(self, *, num_classes, image_size, patch_size, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
@@ -178,7 +178,7 @@ class MP3(nn.Module):
attended_tokens = self.vit.transformer(tokens, tokens_unmasked)
logits = rearrange(self.mlp_head(attended_tokens), 'b n d -> (b n) d')
# Define labels
labels = repeat(torch.arange(num_patches, device = device), 'n -> (b n)', b = batch)
loss = F.cross_entropy(logits, labels)

View File

@@ -343,11 +343,11 @@ class NaViT(nn.Module):
# need to know how many images for final attention pooling
num_images = torch.tensor(num_images, device = device, dtype = torch.long)
# to patches
x = self.to_patch_embedding(patches)
# factorized 2d absolute positional embedding

View File

@@ -64,7 +64,7 @@ class Attention(Module):
self.to_out = nn.Linear(dim_inner, dim, bias = False)
def forward(
self,
x,
context: Tensor | None = None
):

View File

@@ -64,7 +64,7 @@ class Attention(Module):
self.to_out = nn.Linear(dim_inner, dim, bias = False)
def forward(
self,
x,
context: Tensor | None = None
):

View File

@@ -91,7 +91,7 @@ class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', num_parallel_branches = 2, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

View File

@@ -154,7 +154,7 @@ class PiT(nn.Module):
for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
not_last = ind < (len(depth) - 1)
layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))
if not_last:

View File

@@ -146,7 +146,7 @@ class R2LTransformer(nn.Module):
region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')
# calculate local relative positional bias
h_range = torch.arange(window_size_h, device = device)
w_range = torch.arange(window_size_w, device = device)

View File

@@ -187,7 +187,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w)
# add LIM output
out = out + local_out

View File

@@ -26,7 +26,7 @@ def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
omega = 1. / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)
@@ -57,7 +57,7 @@ class Attend(nn.Module):
config = self.cuda_config if q.is_cuda else self.cpu_config
# flash attention - https://arxiv.org/abs/2205.14135
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(q, k, v)
@@ -140,7 +140,7 @@ class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, use_flash = True):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

View File

@@ -34,7 +34,7 @@ def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
z = z.flatten()[:, None] * omega[None, :]
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
@@ -52,7 +52,7 @@ class Attend(Module):
def flash_attn(self, q, k, v):
# flash attention - https://arxiv.org/abs/2205.14135
with torch.backends.cuda.sdp_kernel(**self.config._asdict()):
out = F.scaled_dot_product_attention(q, k, v)
@@ -137,7 +137,7 @@ class SimpleViT(Module):
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, use_flash_attn = True):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(image_patch_size)
self.patch_size = patch_height, patch_width = pair(image_patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'

View File

@@ -35,7 +35,7 @@ def FeedForward(dim, hidden_dim):
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):
@@ -107,7 +107,7 @@ class SimpleUViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_register_tokens = 4, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'

View File

@@ -81,7 +81,7 @@ class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
@@ -98,7 +98,7 @@ class SimpleViT(nn.Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

View File

@@ -26,7 +26,7 @@ def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
z = z.flatten()[:, None] * omega[None, :]
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
@@ -94,7 +94,7 @@ class SimpleViT(nn.Module):
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(image_patch_size)
self.patch_size = patch_height, patch_width = pair(image_patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'

View File

@@ -0,0 +1,243 @@
# @bycloud explanation https://www.youtube.com/watch?v=iw1VF8HOCrk
from __future__ import annotations
import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleList
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def last(arr):
return arr[-1]
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def divisible_by(num, den):
return (num % den) == 0
def posemb_sincos_2d(h, w, dim, temperature = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing = 'ij')
assert divisible_by(dim, 4), 'feature dimension must be multiple of 4 for sincos emb'
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)
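To make the layout produced by `posemb_sincos_2d` easier to inspect, here is a pure-Python sketch of the same computation without torch. The function name `sincos_2d` is hypothetical. It assumes `dim` is a multiple of 4 and at least 8, since the `dim // 4 - 1` denominator is zero when `dim == 4`.

```python
import math

def sincos_2d(h, w, dim, temperature=10000):
    """Pure-Python 2D sin-cos positional table: one row per grid cell.

    Assumes dim is a multiple of 4 and at least 8.
    """
    assert dim % 4 == 0 and dim >= 8
    quarter = dim // 4
    omega = [1.0 / (temperature ** (i / (quarter - 1))) for i in range(quarter)]
    rows = []
    for y in range(h):       # 'ij' indexing: y varies slowest after flatten
        for x in range(w):
            rows.append(
                [math.sin(x * o) for o in omega]
                + [math.cos(x * o) for o in omega]
                + [math.sin(y * o) for o in omega]
                + [math.cos(y * o) for o in omega]
            )
    return rows

pe = sincos_2d(h=2, w=3, dim=8)
print(len(pe), len(pe[0]))
```

Each row is ordered as sin(x), cos(x), sin(y), cos(y), matching the `torch.cat` call in the listing.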
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, cross_attend = False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.norm_context = nn.LayerNorm(dim) if cross_attend else None
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, context = None):
x = self.norm(x)
if exists(context):
context = self.norm_context(context)
else:
context = x
q = self.to_q(x)
k, v = self.to_kv(context).chunk(2, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = dots.softmax(dim = -1)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class AttentionResidual(Module):
def __init__(self, fn, dim, heads = 8, dim_head = 64, learned_query = True, disable = False):
super().__init__()
self.fn = fn
self.disable = disable
if disable:
return
self.attn = Attention(dim, heads = heads, dim_head = dim_head, cross_attend = True)
self.learned_query = nn.Parameter(torch.randn(dim)) if learned_query else None
def forward(self, history: list[Tensor]) -> Tensor:
if self.disable:
return self.fn(last(history))
batch, seq_len = history[0].shape[:2]
context = torch.stack(history, dim = 2)
context = rearrange(context, 'b n l d -> (b n) l d')
if exists(self.learned_query):
q = repeat(self.learned_query, 'd -> (b n) 1 d', b = batch, n = seq_len)
else:
q = rearrange(last(history), 'b n d -> (b n) 1 d')
pooled = self.attn(q, context = context)
pooled = rearrange(pooled, '(b n) 1 d -> b n d', b = batch, n = seq_len)
return self.fn(pooled)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, learned_query = True):
super().__init__()
self.layers = ModuleList([])
for ind in range(depth):
is_first = ind == 0
self.layers.append(ModuleList([
AttentionResidual(Attention(dim, heads = heads, dim_head = dim_head), dim, heads = heads, dim_head = dim_head, learned_query = learned_query, disable = is_first),
AttentionResidual(FeedForward(dim, mlp_dim), dim, heads = heads, dim_head = dim_head, learned_query = learned_query),
]))
self.final_pool = AttentionResidual(nn.LayerNorm(dim), dim, heads = heads, dim_head = dim_head, learned_query = learned_query)
def forward(
self,
x,
history: list[Tensor] | None = None,
return_history = False
):
history = [*default(history, [])]
history.append(x)
for attn_residual, ff_residual in self.layers:
history.append(attn_residual(history))
history.append(ff_residual(history))
out = self.final_pool(history)
if return_history:
return out, history
return out
class SimpleViTAttnResidual(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
channels = 3,
dim_head = 64,
learned_query = True
):
super().__init__()
image_height, image_width = pair(image_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, learned_query = learned_query)
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(
self,
img,
history: list[Tensor] | None = None,
return_history = False
):
device, dtype = img.device, img.dtype
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype = dtype)
x = self.transformer(x, history = history, return_history = return_history)
if return_history:
x, history = x
x = x.mean(dim = 1)
x = self.to_latent(x)
out = self.linear_head(x)
if return_history:
return out, history
return out
if __name__ == '__main__':
for learned_query in (True, False):
v = SimpleViTAttnResidual(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
learned_query = learned_query
)
img = torch.randn(2, 3, 256, 256)
preds, history = v(img, return_history = True)
assert preds.shape == (2, 1000)
preds, _ = v(img, history = history, return_history = True)
assert preds.shape == (2, 1000)

View File

@@ -0,0 +1,206 @@
# Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks
# Giyeong Oh et al. https://arxiv.org/abs/2505.11881
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from einops import rearrange
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class OrthogonalResidualUpdate(Module):
def __init__(
self,
block: Module,
dim = None,
double_precision = True,
learned = False
):
super().__init__()
self.block = block
self.double_precision = double_precision
self.learned = learned
if learned:
assert exists(dim)
self.to_modulation = nn.Linear(dim, 2)
def orthog_proj(self, block_out, residual):
use_double, dtype = self.double_precision, residual.dtype
if use_double:
residual, block_out = residual.double(), block_out.double()
# get the orthogonal projection of the attention or feedforward output with respect to the residual
unit = F.normalize(residual, dim = -1)
parallel = (block_out * unit).sum(dim = -1, keepdim = True) * unit
orthogonal = block_out - parallel
# back to original dtype if double precision
if use_double:
parallel, orthogonal = parallel.to(dtype), orthogonal.to(dtype)
return parallel, orthogonal
def forward(self, residual):
block_out = self.block(residual)
parallel_update, orthog_update = self.orthog_proj(block_out, residual)
if self.learned:
parallel_mod, orthog_mod = self.to_modulation(block_out).sigmoid().split(1, dim = -1)
parallel_update = parallel_update * parallel_mod
orthog_update = orthog_update * orthog_mod
else:
parallel_update = 0
return residual + parallel_update + orthog_update
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, orthog_residual_update_kwargs: dict = dict()):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
attn = Attention(dim, heads = heads, dim_head = dim_head)
ff = FeedForward(dim, mlp_dim)
self.layers.append(ModuleList([
OrthogonalResidualUpdate(attn, dim = dim, **orthog_residual_update_kwargs),
OrthogonalResidualUpdate(ff, dim = dim, **orthog_residual_update_kwargs)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x)
x = ff(x)
return self.norm(x)
class SimpleViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, orthog_residual_update_kwargs: dict = dict()):
super().__init__()
image_height, image_width = pair(image_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, orthog_residual_update_kwargs)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
device = img.device
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype=x.dtype)
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
# quick test
if __name__ == '__main__':
vit = SimpleViT(
image_size = 256,
patch_size = 16,
num_classes = 10,
dim = 512,
depth = 2,
heads = 4,
mlp_dim = 2048,
orthog_residual_update_kwargs = dict(
learned = True
)
)
images = torch.randn(2, 3, 256, 256)
assert vit(images).shape == (2, 10)
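
Reviewer note: the projection inside `OrthogonalResidualUpdate.orthog_proj` can be checked by hand on a single vector. The sketch below is a plain-Python illustration, not the module itself. `orthog_decompose` is a hypothetical helper that works on lists instead of tensors, and it assumes the same `max(norm, eps)` normalization that `F.normalize` applies.

```python
import math

def orthog_decompose(block_out, residual, eps = 1e-12):
    """Split block_out into parts parallel and orthogonal to residual (hypothetical helper)."""
    norm = math.sqrt(sum(r * r for r in residual))
    unit = [r / max(norm, eps) for r in residual]         # unit vector along the residual
    coeff = sum(b * u for b, u in zip(block_out, unit))   # scalar projection onto that unit vector
    parallel = [coeff * u for u in unit]
    orthogonal = [b - p for b, p in zip(block_out, parallel)]
    return parallel, orthogonal

residual = [3.0, 4.0]
block_out = [1.0, 2.0]

parallel, orthogonal = orthog_decompose(block_out, residual)

# the orthogonal part should have a (numerically) zero dot product with the residual
dot = sum(o * r for o, r in zip(orthogonal, residual))

# non-learned path: only the orthogonal part is added back to the residual stream
updated = [r + o for r, o in zip(residual, orthogonal)]
```

With `learned = False` the module sets the parallel update to zero, which is what `updated` reproduces here. With `learned = True` both parts are kept and scaled by sigmoid gates instead.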


@@ -82,7 +82,7 @@ class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, freq_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
-patch_height, patch_width = pair(patch_size)
+self.patch_size = patch_height, patch_width = pair(patch_size)
freq_patch_height, freq_patch_width = pair(freq_patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'


@@ -167,7 +167,7 @@ class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_residual_streams, num_register_tokens = 4, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
-patch_height, patch_width = pair(patch_size)
+self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
@@ -186,7 +186,7 @@ class SimpleViT(nn.Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
-)
+)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams)


@@ -18,7 +18,7 @@ def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
omega = 1. / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
-x = x.flatten()[:, None] * omega[None, :]
+x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)
@@ -104,7 +104,7 @@ class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, patch_dropout = 0.5):
super().__init__()
image_height, image_width = pair(image_size)
-patch_height, patch_width = pair(patch_size)
+self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'


@@ -102,7 +102,7 @@ class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
-patch_height, patch_width = pair(patch_size)
+self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
@@ -119,7 +119,7 @@ class SimpleViT(nn.Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
-)
+)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)


@@ -86,7 +86,7 @@ class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_register_tokens = 4, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
-patch_height, patch_width = pair(patch_size)
+self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
@@ -105,7 +105,7 @@ class SimpleViT(nn.Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
-)
+)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)


@@ -0,0 +1,205 @@
from __future__ import annotations
# Alexis Marouani et al. https://arxiv.org/abs/2602.08626
import torch
from torch import nn, cat, Tensor, is_tensor
from torch.nn import Module, ModuleList
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# classes
class Specialized(Module):
def __init__(self, modules: list[Module]):
super().__init__()
self.fns = ModuleList(modules)
def forward(
self,
x: Tensor | list[Tensor],
token_lens: tuple[int, ...] = None
):
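is_tensor_input = is_tensor(x)
if is_tensor_input:
assert exists(token_lens)
x = x.split(token_lens, dim = 1)
assert len(self.fns) == len(x)
out = tuple(fn(t) for fn, t in zip(self.fns, x))
# only concatenate when the input arrived as a single tensor
if is_tensor_input: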
out = cat(out, dim = 1)
return out
class FeedForward(Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.norm = Specialized([
nn.LayerNorm(dim),
nn.LayerNorm(dim)
])
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x, token_lens = None):
x = self.norm(x, token_lens = token_lens)
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, specialize_qkv = False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = Specialized([
nn.LayerNorm(dim),
nn.LayerNorm(dim)
])
self.attend = nn.Softmax(dim = -1)
self.specialize_qkv = specialize_qkv
if specialize_qkv:
self.to_qkv = Specialized([
nn.Linear(dim, inner_dim * 3, bias = False),
nn.Linear(dim, inner_dim * 3, bias = False)
])
else:
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, token_lens = None):
x = self.norm(x, token_lens = token_lens)
if self.specialize_qkv:
qkv = self.to_qkv(x, token_lens = token_lens).chunk(3, dim = -1)
else:
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, specialize_qkv_depth):
super().__init__()
self.norm = Specialized([nn.LayerNorm(dim), nn.LayerNorm(dim)])
self.layers = ModuleList([])
for ind in range(depth):
specialize_qkv = ind < specialize_qkv_depth
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, specialize_qkv = specialize_qkv),
FeedForward(dim, mlp_dim)
]))
def forward(self, x, token_lens = None):
for attn, ff in self.layers:
x = attn(x, token_lens = token_lens) + x
x = ff(x, token_lens = token_lens) + x
return self.norm(x, token_lens = token_lens)
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, specialize_qkv_depth = None):
super().__init__()
image_height, image_width = pair(image_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.cls_token = nn.Parameter(torch.randn(dim) * 1e-2)
specialize_qkv_depth = default(specialize_qkv_depth, depth // 3) # the authors found that a specialized qkv projection for the cls token in only the first third of the layers is enough
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, specialize_qkv_depth)
self.pool = 'cls'
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
device = img.device
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype=x.dtype)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)
x = cat((cls_tokens, x), dim = 1)
x = self.transformer(x, token_lens = (1, n))
x = x[:, 0]
x = self.to_latent(x)
return self.linear_head(x)
if __name__ == '__main__':
v = SimpleViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048
)
img = torch.randn(1, 3, 256, 256)
out = v(img)
assert out.shape == (1, 1000)
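
Reviewer note: `Specialized` routes each token group (here the cls token and the patch tokens) through its own module and then stitches the sequence back together. A minimal list-based sketch of that split, apply, concatenate pattern follows. `split_apply_concat` and the two toy functions are hypothetical names. The sketch applies each function token by token, which matches the per-token modules used above (`LayerNorm`, `Linear`) but is an assumption for anything else.

```python
def split_apply_concat(tokens, token_lens, fns):
    """Apply fns[i] to the i-th contiguous group of tokens, then rejoin in order (hypothetical helper)."""
    assert sum(token_lens) == len(tokens), 'token_lens must cover the whole sequence'
    assert len(fns) == len(token_lens), 'one function per token group'

    out, start = [], 0
    for fn, length in zip(fns, token_lens):
        group = tokens[start:start + length]   # analogous to x.split(token_lens, dim = 1)
        out.extend(fn(t) for t in group)       # each group sees only its own parameters
        start += length
    return out                                 # analogous to cat(out, dim = 1)

cls_fn = lambda t: t * 10     # stand-in for the cls-specific module
patch_fn = lambda t: t + 1    # stand-in for the patch-token module

# one cls token followed by three patch tokens, as in token_lens = (1, n)
result = split_apply_concat([1, 2, 3, 4], (1, 3), [cls_fn, patch_fn])
```

The sequence length and token order are unchanged, so the attention step that follows can still treat the result as one ordinary sequence.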


@@ -103,7 +103,7 @@ class SimpleViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
-patch_height, patch_width = pair(patch_size)
+self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
@@ -120,7 +120,7 @@ class SimpleViT(Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
-)
+)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)


@@ -41,7 +41,7 @@ def posemb_sincos_2d(
omega = temperature ** -omega
y = y.flatten()[:, None] * omega[None, :]
-x = x.flatten()[:, None] * omega[None, :]
+x = x.flatten()[:, None] * omega[None, :]
pe = cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
pe = pe.type(dtype)
@@ -120,6 +120,12 @@ class Attention(Module):
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
+self.to_out_gates = nn.Sequential(
+nn.Linear(dim, heads),
+Rearrange('b ... h -> b h ... 1'),
+nn.Sigmoid()
+)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
@@ -150,6 +156,9 @@ class Attention(Module):
attn = self.dropout(attn)
out = torch.matmul(attn, v)
+out = out * self.to_out_gates(x)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -222,7 +231,7 @@ class AST(Module):
self.dim = dim
self.depth = depth
-patch_height, patch_width = pair(patch_size)
+self.patch_size = patch_height, patch_width = pair(patch_size)
patch_input_dim = patch_height * patch_width
self.patch_size = (patch_height, patch_width)
@@ -348,7 +357,7 @@ class ViT(Module):
self.depth = depth
image_height, image_width = pair(image_size)
-patch_height, patch_width = pair(patch_size)
+self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
@@ -433,7 +442,8 @@ class VAAT(Module):
self_attn_heads = 4,
self_attn_dim_head = 32,
ast_layer_indices: tuple[int, ...] | None = None,
-vit_layer_indices: tuple[int, ...] | None = None
+vit_layer_indices: tuple[int, ...] | None = None,
+num_advantage_bins = 0
):
super().__init__()
@@ -471,7 +481,7 @@ class VAAT(Module):
assert len(ast_layer_indices) == depth, f'number of ast layer indices {len(ast_layer_indices)} does not match the VAAT depth {depth}'
-self.register_buffer('ast_layer_indices', tensor(vit_layer_indices), persistent = False)
+self.register_buffer('ast_layer_indices', tensor(ast_layer_indices), persistent = False)
# handle maybe multiple frames
@@ -502,6 +512,14 @@ class VAAT(Module):
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
+# handle maybe advantage conditioning
+self.has_advantages = num_advantage_bins > 0
+self.num_advantage_bins = num_advantage_bins
+if self.has_advantages:
+self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)
self.layers = ModuleList([])
for _ in range(depth):
@@ -531,14 +549,15 @@ class VAAT(Module):
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
audio_or_spec, # (b v? t) | (b v? f t) - batch, audio len | batch, spec freq, time
*,
-extra = None, # (b d) - batch, dim extra
+extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
+advantages = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False,
freeze_ast = False
):
-batch = video_or_image.shape[0]
+batch, device = video_or_image.shape[0], video_or_image.device
return_loss = exists(actions)
# handle some various input dimensions
@@ -646,53 +665,66 @@ class VAAT(Module):
audio_context = rearrange(audio_hiddens, 'l b v n d -> l b (v n) d')
-# get main action tokens and maybe append extra
+# main action tokens
-action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
+action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)
-has_extra = exists(extra)
+# maybe advantage tokens
-if has_extra:
-assert self.accept_extra_token
+empty_token = action_tokens[:, 0:0]
-extra_token = self.to_extra_token(extra)
+maybe_advantage_embed = empty_token
-action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
+if self.has_advantages and exists(advantages):
+if isinstance(advantages, int):
+advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)
+maybe_advantage_embed = self.advantage_emb(advantages + 1)
# register tokens
-register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
+register_tokens = empty_token
-action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
+if exists(self.register_tokens):
+register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
-# cross attention
+# extra
-hiddens = [action_tokens]
+maybe_extra_embed = empty_token
+has_extra = exists(extra)
+if has_extra:
+assert self.accept_extra_token
+maybe_extra_embed = self.to_extra_token(extra)
+# pack all tokens for attention
+tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')
+# transformer
+hiddens = [tokens]
for (maybe_film, maybe_self_attn, image_cross_attn, audio_cross_attn, ff), image_layer_context, audio_layer_context in zip(self.layers, image_context, audio_context):
-if exists(tasks):
-action_tokens = maybe_film(action_tokens, task_emb)
+if exists(maybe_film) and exists(tasks):
+tokens = maybe_film(tokens, task_emb)
-action_tokens = image_cross_attn(action_tokens, image_layer_context) + action_tokens
+tokens = image_cross_attn(tokens, image_layer_context) + tokens
-action_tokens = audio_cross_attn(action_tokens, audio_layer_context) + action_tokens
+tokens = audio_cross_attn(tokens, audio_layer_context) + tokens
if exists(maybe_self_attn):
-action_tokens = maybe_self_attn(action_tokens) + action_tokens
+tokens = maybe_self_attn(tokens) + tokens
-action_tokens = ff(action_tokens) + action_tokens
+tokens = ff(tokens) + tokens
-hiddens.append(action_tokens)
+hiddens.append(tokens)
-# unpack registers
+# unpack register, advantage, action, and extra tokens
-_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
-# maybe unpack extra
-if has_extra:
-action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
+maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')
# norm and prediction
@@ -735,43 +767,51 @@ if __name__ == '__main__':
mlp_dim = 384 * 4
)
-vaat = VAAT(
-vit,
-ast,
-dim = 512,
-depth = 9,
-heads = 8,
-dim_head = 64,
-mlp_dim = 2048,
-dim_action = 20,
-action_chunk_len = 7,
-time_seq_len = 4,
-num_image_views = 2,
-num_audio_views = 2,
-num_tasks = 4,
-add_self_attn = True,
-dim_extra_token = 33, # extra token with some variable dimension
-vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
-0, 0, 1, 1, 2, 2, 3, 3, 4
-),
-ast_layer_indices = (
-1, 1, 1, 2, 2, 2, 3, 3, 3
+for num_adv_bins in (0, 2, 10):
+vaat = VAAT(
+vit,
+ast,
+dim = 512,
+depth = 9,
+heads = 8,
+dim_head = 64,
+mlp_dim = 2048,
+dim_action = 20,
+action_chunk_len = 7,
+time_seq_len = 4,
+num_image_views = 2,
+num_audio_views = 2,
+num_tasks = 4,
+num_advantage_bins = num_adv_bins,
+add_self_attn = True,
+dim_extra_token = 33, # extra token with some variable dimension
+vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
+0, 0, 1, 1, 2, 2, 3, 3, 4
+),
+ast_layer_indices = (
+1, 1, 1, 2, 2, 2, 3, 3, 3
+)
)
-)
-images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
-audio = torch.randn(2, 2, 14_100 * 5)
+images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
+audio = torch.randn(2, 2, 14_100 * 5)
-tasks = torch.randint(0, 4, (2,))
-extra = torch.randn(2, 33) # extra internal state
+tasks = torch.randint(0, 4, (2,))
+extra = torch.randn(2, 33) # extra internal state
-actions = torch.randn(2, 7, 20) # actions for learning
+# advantage conditioning
-loss = vaat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
-loss.backward()
+advantages = None
+if num_adv_bins > 0:
+advantages = torch.randint(-1, num_adv_bins, (2,))
-# after much training
+actions = torch.randn(2, 7, 20) # actions for learning
-pred_actions, hiddens = vaat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
+loss = vaat(images, audio, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
+loss.backward()
-assert pred_actions.shape == (2, 7, 20)
+# after much training
+pred_actions, hiddens = vaat(images, audio, advantages = advantages, tasks = tasks, extra = extra, return_hiddens = True)
+assert pred_actions.shape == (2, 7, 20)
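
Reviewer note: the rewritten forward pass packs register, advantage, action and extra tokens into one sequence and lets absent groups contribute a zero-length slice (`empty_token`). The sketch below mimics that bookkeeping with plain lists. `pack_groups` and `unpack_groups` are hypothetical stand-ins for `einops.pack` and `einops.unpack`, which record packed shapes rather than list lengths.

```python
def pack_groups(groups):
    """Concatenate token groups and remember each group's length (hypothetical helper)."""
    lens = [len(g) for g in groups]
    packed = [tok for g in groups for tok in g]
    return packed, lens

def unpack_groups(packed, lens):
    """Inverse of pack_groups: slice the packed sequence back into its groups."""
    out, start = [], 0
    for n in lens:
        out.append(packed[start:start + n])
        start += n
    return out

registers = ['r0', 'r1']
advantage = []                  # advantage conditioning disabled: zero-length group
actions = ['a0', 'a1', 'a2']
extra = ['e0']

tokens, ps = pack_groups([registers, advantage, actions, extra])

# ... the transformer layers would update `tokens` here without changing its length ...

reg_out, adv_out, act_out, extra_out = unpack_groups(tokens, ps)
```

Because an empty group still occupies a slot in `ps`, the unpack step always returns four groups in the same order, so no per-feature branching is needed after the transformer.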


@@ -92,6 +92,12 @@ class Attention(Module):
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
+self.to_out_gates = nn.Sequential(
+nn.Linear(dim, heads),
+Rearrange('b ... h -> b h ... 1'),
+nn.Sigmoid()
+)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
@@ -122,6 +128,8 @@ class Attention(Module):
attn = self.dropout(attn)
out = torch.matmul(attn, v)
+out = out * self.to_out_gates(x) # https://arxiv.org/abs/2505.06708
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -189,7 +197,7 @@ class ViT(Module):
self.depth = depth
image_height, image_width = pair(image_size)
-patch_height, patch_width = pair(patch_size)
+self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
@@ -270,7 +278,8 @@ class VAT(Module):
add_self_attn = True, # in the paper, they didn't have any ways for the action token to exchange information with the extra token, so we'll just add it as an option
self_attn_heads = 4,
self_attn_dim_head = 32,
-vit_layer_indices: tuple[int, ...] | None = None
+vit_layer_indices: tuple[int, ...] | None = None,
+num_advantage_bins = 0
):
super().__init__()
@@ -316,6 +325,14 @@ class VAT(Module):
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
+# handle maybe advantage conditioning
+self.has_advantages = num_advantage_bins > 0
+self.num_advantage_bins = num_advantage_bins
+if self.has_advantages:
+self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)
self.layers = ModuleList([])
for _ in range(depth):
@@ -343,13 +360,14 @@ class VAT(Module):
self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
*,
-extra = None, # (b d) - batch, dim extra
+extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
+advantages = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False
):
-batch = video_or_image.shape[0]
+batch, device = video_or_image.shape[0], video_or_image.device
return_loss = exists(actions)
# handle some various input dimensions
@@ -415,51 +433,64 @@ class VAT(Module):
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
-# get main action tokens and maybe append extra
+# main action tokens
-action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
+action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)
-has_extra = exists(extra)
+# maybe advantage tokens
-if has_extra:
-assert self.accept_extra_token
+empty_token = action_tokens[:, 0:0]
-extra_token = self.to_extra_token(extra)
+maybe_advantage_embed = empty_token
-action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
+if self.has_advantages and exists(advantages):
+if isinstance(advantages, int):
+advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)
+maybe_advantage_embed = self.advantage_emb(advantages + 1)
# register tokens
-register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
+register_tokens = empty_token
-action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
+if exists(self.register_tokens):
+register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
-# cross attention
+# extra
-hiddens = [action_tokens]
+maybe_extra_embed = empty_token
+has_extra = exists(extra)
+if has_extra:
+assert self.accept_extra_token
+maybe_extra_embed = self.to_extra_token(extra)
+# pack all tokens for attention
+tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')
+# transformer
+hiddens = [tokens]
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
-if exists(tasks):
-action_tokens = maybe_film(action_tokens, task_emb)
+if exists(maybe_film) and exists(tasks):
+tokens = maybe_film(tokens, task_emb)
-action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
+tokens = cross_attn(tokens, layer_context) + tokens
if exists(maybe_self_attn):
-action_tokens = maybe_self_attn(action_tokens) + action_tokens
+tokens = maybe_self_attn(tokens) + tokens
-action_tokens = ff(action_tokens) + action_tokens
+tokens = ff(tokens) + tokens
-hiddens.append(action_tokens)
+hiddens.append(tokens)
-# unpack registers
+# unpack register, advantage, action, and extra tokens
-_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
-# maybe unpack extra
-if has_extra:
-action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
+maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')
# norm and prediction
@@ -493,36 +524,44 @@ if __name__ == '__main__':
mlp_dim = 1024
)
-vat = VAT(
-vit,
-dim = 512,
-depth = 9,
-heads = 8,
-dim_head = 64,
-mlp_dim = 2048,
-dim_action = 20,
-action_chunk_len = 7,
-time_seq_len = 4,
-num_views = 2,
-num_tasks = 4,
-add_self_attn = True,
-dim_extra_token = 33, # extra token with some variable dimension
-vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
-0, 0, 1, 1, 2, 2, 3, 3, 4
+for num_adv_bins in (0, 2, 10):
+vat = VAT(
+vit,
+dim = 512,
+depth = 9,
+heads = 8,
+dim_head = 64,
+mlp_dim = 2048,
+dim_action = 20,
+action_chunk_len = 7,
+time_seq_len = 4,
+num_views = 2,
+num_tasks = 4,
+num_advantage_bins = num_adv_bins,
+add_self_attn = True,
+dim_extra_token = 33, # extra token with some variable dimension
+vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
+0, 0, 1, 1, 2, 2, 3, 3, 4
+)
)
-)
-images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
-tasks = torch.randint(0, 4, (2,))
-extra = torch.randn(2, 33) # extra internal state
+images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
+tasks = torch.randint(0, 4, (2,))
+extra = torch.randn(2, 33) # extra internal state
-actions = torch.randn(2, 7, 20) # actions for learning
+# advantage conditioning
-loss = vat(images, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
-loss.backward()
+advantages = None
+if num_adv_bins > 0:
+advantages = torch.randint(-1, num_adv_bins, (2,))
-# after much training
+actions = torch.randn(2, 7, 20) # actions for learning
-pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True)
+loss = vat(images, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
+loss.backward()
-assert pred_actions.shape == (2, 7, 20)
+# after much training
+pred_actions, hiddens = vat(images, advantages = advantages, tasks = tasks, extra = extra, return_hiddens = True)
+assert pred_actions.shape == (2, 7, 20)
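
Reviewer note: `advantage_emb` has `num_advantage_bins + 1` rows and is indexed with `advantages + 1`, so a bin id of `-1` lands on row 0. The sketch below spells out that index shift. Reading `-1` as "no advantage given" is an assumption, and `bin_advantage` with its `edges` argument is a hypothetical way to turn a continuous advantage into a bin id. The diff does not show how bin ids are produced.

```python
from bisect import bisect_right

def bin_advantage(value, edges):
    # hypothetical bucketing: `edges` are ascending boundaries,
    # giving len(edges) + 1 bins with ids 0 .. len(edges)
    return bisect_right(edges, value)

def embedding_row(bin_id, num_advantage_bins):
    # mirrors `self.advantage_emb(advantages + 1)`:
    # bin ids -1 .. num_advantage_bins - 1 map to rows 0 .. num_advantage_bins
    assert -1 <= bin_id < num_advantage_bins
    return bin_id + 1

edges = [-0.5, 0.5]                   # assumed boundaries, three bins
num_advantage_bins = len(edges) + 1

rows = [embedding_row(bin_advantage(v, edges), num_advantage_bins) for v in (-1.0, 0.0, 0.7)]
missing_row = embedding_row(-1, num_advantage_bins)   # assumed meaning: advantage unknown
```

The quick test in the diff samples ids with `torch.randint(-1, num_adv_bins, (2,))`, which covers exactly the range accepted by `embedding_row` here.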

vit_pytorch/vat_siglip.py (new file, 521 lines)

@@ -0,0 +1,521 @@
from __future__ import annotations
from contextlib import nullcontext
from pathlib import Path
import torch
import torch.nn.functional as F
from torch import nn, cat, stack, tensor, einsum
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# attention
class Attention(Module):
def __init__(
self,
dim,
dim_context = None,
heads = 8,
dim_head = 64,
dropout = 0.,
norm_eps = 1e-6,
gate_attn = False
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.norm = nn.LayerNorm(dim, eps = norm_eps)
self.is_cross_attn = exists(dim_context)
dim_context = default(dim_context, dim)
self.norm_context = nn.LayerNorm(dim_context, eps = norm_eps) if self.is_cross_attn else None
self.to_q = nn.Linear(dim, inner_dim)
self.to_kv = nn.Linear(dim_context, inner_dim * 2)
self.to_out_gates = nn.Sequential(
nn.Linear(dim, heads),
Rearrange('b ... h -> b h ... 1'),
nn.Sigmoid()
) if gate_attn else None
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, context = None):
x = self.norm(x)
if self.is_cross_attn:
assert exists(context)
context = self.norm_context(context)
else:
context = x
q = self.to_q(x)
k, v = self.to_kv(context).chunk(2, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
if exists(self.to_out_gates):
out = out * self.to_out_gates(x) # https://arxiv.org/abs/2505.06708
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
def FeedForward(
dim,
dim_inner,
norm_eps = 1e-6
):
return nn.Sequential(
nn.LayerNorm(dim, eps = norm_eps),
nn.Linear(dim, dim_inner),
nn.GELU(approximate = 'tanh'),
nn.Linear(dim_inner, dim)
)
class SigLIP(Module):
def __init__(
self,
image_size = 224,
patch_size = 14,
dim = 1152,
depth = 27,
heads = 16,
mlp_dim = 4304,
norm_eps = 1e-6
):
super().__init__()
self.dim = dim
self.depth = depth
num_patches = (image_size // patch_size) ** 2
dim_head = dim // heads
self.to_patch_embed = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear(patch_size * patch_size * 3, dim)
)
self.pos_embed = nn.Parameter(torch.randn(num_patches, dim))
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, norm_eps = norm_eps),
FeedForward(dim = dim, dim_inner = mlp_dim, norm_eps = norm_eps)
]))
self.norm = nn.LayerNorm(dim, eps = norm_eps)
def forward(self, x, return_hiddens = False):
x = self.to_patch_embed(x)
num_patches = x.shape[1]
x = x + self.pos_embed[:num_patches]
hiddens = []
for attn, ff in self.layers:
hiddens.append(x)
x = attn(x) + x
x = ff(x) + x
out = self.norm(x)
if not return_hiddens:
return out
return out, stack(hiddens)
class FiLM(Module):
def __init__(self, dim):
super().__init__()
proj = nn.Linear(dim, dim * 2)
self.to_gamma_beta = nn.Sequential(
proj,
Rearrange('b (two d) -> two b 1 d', two = 2)
)
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)
def forward(self, tokens, cond):
gamma, beta = self.to_gamma_beta(cond)
return tokens * gamma + beta
class SigLIPVAT(Module):
def __init__(
self,
*,
dim = 512,
depth = 27,
heads = 8,
dim_head = 64,
dim_action = 32,
mlp_dim = 2048,
num_views = 1,
num_tasks = None,
dim_extra_token = None,
num_register_tokens = 4,
action_chunk_len = 50,
time_seq_len = 1,
dropout = 0.,
add_self_attn = True,
self_attn_heads = 4,
self_attn_dim_head = 32,
vit_layer_indices: tuple[int, ...] | None = None,
num_advantage_bins = 0,
siglip_image_size = 224,
siglip_patch_size = 14,
siglip_dim = 1152,
siglip_depth = 27,
siglip_heads = 16,
siglip_mlp_dim = 4304,
siglip_norm_eps = 1e-6,
):
super().__init__()
self.vit = SigLIP(
image_size = siglip_image_size,
patch_size = siglip_patch_size,
dim = siglip_dim,
depth = siglip_depth,
heads = siglip_heads,
mlp_dim = siglip_mlp_dim,
norm_eps = siglip_norm_eps
)
vit_dim = siglip_dim
self.vit_dim = vit_dim
vit_layer_indices = default(vit_layer_indices, tuple(range(depth)))
assert len(vit_layer_indices) == depth, f'number of vit layer indices {len(vit_layer_indices)} does not match the VAT depth {depth}'
self.register_buffer('layer_indices', tensor(vit_layer_indices), persistent = False)
# handle maybe multiple frames
is_video = time_seq_len > 1
self.is_video = is_video
self.time_seq_len = time_seq_len
self.time_pos_emb = nn.Parameter(torch.randn(time_seq_len, vit_dim) * 1e-2) if is_video else None
# maybe view embeddings
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
# handle maybe task conditioning
self.has_tasks = exists(num_tasks)
if self.has_tasks:
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
# register tokens
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
# to action tokens
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
# handle maybe advantage conditioning
self.has_advantages = num_advantage_bins > 0
self.num_advantage_bins = num_advantage_bins
if self.has_advantages:
self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)
self.layers = ModuleList([])
for _ in range(depth):
maybe_film = FiLM(dim = dim) if self.has_tasks else None
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
self.layers.append(ModuleList([
maybe_film,
maybe_self_attn,
Attention(dim = dim, dim_context = vit_dim, heads = heads, dim_head = dim_head, dropout = dropout, gate_attn = True),
FeedForward(dim = dim, dim_inner = mlp_dim)
]))
self.final_norm = nn.LayerNorm(dim)
self.to_pred_action = nn.Linear(dim, dim_action, bias = False)
# handle the extra token
self.accept_extra_token = exists(dim_extra_token)
if exists(dim_extra_token):
self.to_extra_token = nn.Linear(dim_extra_token, dim)
def load_siglip(
self,
repo_id = 'google/siglip-so400m-patch14-224',
folder = 'checkpoints/siglip'
):
folder = Path(folder)
if not folder.exists():
from huggingface_hub import snapshot_download
snapshot_download(
repo_id = repo_id,
local_dir = folder,
allow_patterns = ['config.json', 'model.safetensors']
)
from safetensors import safe_open
weights_path = folder / 'model.safetensors'
# Auto-detect prefix based on keys
with safe_open(weights_path, framework = 'pt') as f:
keys = f.keys()
vi_p = ''
if any(k.startswith('paligemma_with_expert.paligemma.model.vision_tower.vision_model') for k in keys):
vi_p = 'paligemma_with_expert.paligemma.model.vision_tower.vision_model.'
elif any(k.startswith('vision_model') for k in keys):
vi_p = 'vision_model.'
pz_state = self.vit.state_dict()
def copy_weight_bias(pz_prefix, vi_prefix):
pz_state[f'{pz_prefix}.weight'].copy_(f.get_tensor(f'{vi_prefix}.weight'))
pz_state[f'{pz_prefix}.bias'].copy_(f.get_tensor(f'{vi_prefix}.bias'))
# patch embedding
patch_weight = rearrange(f.get_tensor(f'{vi_p}embeddings.patch_embedding.weight'), 'd c h w -> d (h w c)')
pz_state['to_patch_embed.1.weight'].copy_(patch_weight)
pz_state['to_patch_embed.1.bias'].copy_(f.get_tensor(f'{vi_p}embeddings.patch_embedding.bias'))
# position embedding
pz_state['pos_embed'].copy_(f.get_tensor(f'{vi_p}embeddings.position_embedding.weight'))
# transformer layers
for i in range(self.vit.depth):
v_pi = f'{vi_p}encoder.layers.{i}'
v_pz = f'layers.{i}'
# attention
copy_weight_bias(f'{v_pz}.0.norm', f'{v_pi}.layer_norm1')
copy_weight_bias(f'{v_pz}.0.to_q', f'{v_pi}.self_attn.q_proj')
vk, vv = [f.get_tensor(f'{v_pi}.self_attn.{x}_proj.weight') for x in ('k', 'v')]
bk, bv = [f.get_tensor(f'{v_pi}.self_attn.{x}_proj.bias') for x in ('k', 'v')]
pz_state[f'{v_pz}.0.to_kv.weight'].copy_(cat((vk, vv), dim = 0))
pz_state[f'{v_pz}.0.to_kv.bias'].copy_(cat((bk, bv), dim = 0))
copy_weight_bias(f'{v_pz}.0.to_out.0', f'{v_pi}.self_attn.out_proj')
# feedforward
copy_weight_bias(f'{v_pz}.1.0', f'{v_pi}.layer_norm2')
copy_weight_bias(f'{v_pz}.1.1', f'{v_pi}.mlp.fc1')
copy_weight_bias(f'{v_pz}.1.3', f'{v_pi}.mlp.fc2')
# post-layernorm
copy_weight_bias('norm', f'{vi_p}post_layernorm')
self.vit.load_state_dict(pz_state)
print(f'Successfully loaded SigLIP weights from {repo_id}')
def forward(
self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
*,
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
advantages = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False
):
batch, device = video_or_image.shape[0], video_or_image.device
return_loss = exists(actions)
# handle some various input dimensions
if video_or_image.ndim == 4:
video_or_image = rearrange(video_or_image, 'b c h w -> b 1 c h w')
if video_or_image.ndim == 5:
video_or_image = rearrange(video_or_image, 'b v c h w -> b v c 1 h w')
assert video_or_image.shape[3] == self.time_seq_len
# to images
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
images, packed_shape = pack([images], '* c h w')
# get representation trajectory from vit
vit_forward_context = torch.no_grad if freeze_vit else nullcontext
with vit_forward_context():
embed, hiddens = self.vit(images, return_hiddens = True)
hiddens = cat((hiddens, embed[None, ...]))
# extract the hiddens needed for the action cross attention
hiddens = hiddens[self.layer_indices]
# unpack to restore the batch, view and time dimensions before adding embeddings
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers
# maybe add time embeddings
if self.is_video:
time_pos_emb = rearrange(self.time_pos_emb, 't d -> t 1 d')
hiddens = hiddens + time_pos_emb
# maybe view embeddings
if exists(self.view_emb):
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
hiddens = hiddens + view_emb
# maybe tasks
if exists(tasks):
task_emb = self.task_emb[tasks]
# cross from actions to representation trajectory
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
# main action tokens
action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)
# maybe advantage tokens
empty_token = action_tokens[:, 0:0]
maybe_advantage_embed = empty_token
if self.has_advantages and exists(advantages):
if isinstance(advantages, int):
advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)
maybe_advantage_embed = self.advantage_emb(advantages + 1)
# register tokens
register_tokens = empty_token
if exists(self.register_tokens):
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
# extra
maybe_extra_embed = empty_token
has_extra = exists(extra)
if has_extra:
maybe_extra_embed = self.to_extra_token(extra)
# pack all tokens for attention
tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')
# transformer
vat_hiddens = [tokens]
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
if exists(maybe_film) and exists(tasks):
tokens = maybe_film(tokens, task_emb)
tokens = cross_attn(tokens, layer_context) + tokens
if exists(maybe_self_attn):
tokens = maybe_self_attn(tokens) + tokens
tokens = ff(tokens) + tokens
vat_hiddens.append(tokens)
# unpack register, advantage, action, and extra tokens
maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')
# norm and prediction
action_tokens = self.final_norm(action_tokens)
pred_action = self.to_pred_action(action_tokens)
if not return_loss:
if not return_hiddens:
return pred_action
return pred_action, stack(vat_hiddens)
assert pred_action.shape[1] == actions.shape[1]
return F.l1_loss(pred_action, actions)
# quick test
if __name__ == '__main__':
for num_adv_bins in (0, 2, 10):
vat = SigLIPVAT(
num_tasks = 4,
dim_extra_token = 32,
time_seq_len = 2,
num_views = 2,
depth = 4,
num_advantage_bins = num_adv_bins,
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
0, 1, 26, 27
)
)
vat.load_siglip() # load siglip weights from hf
# inputs
images = torch.randn(1, 2, 3, 2, 224, 224) # (b, v, c, t, h, w)
tasks = torch.randint(0, 4, (1,))
extra = torch.randn(1, 32)
# advantage conditioning
advantages = None
if num_adv_bins > 0:
advantages = torch.randint(-1, num_adv_bins, (1,))
actions = torch.randn(1, 50, 32) # actions for learning
loss = vat(images, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
# after much training
pred_actions = vat(images, advantages = advantages, tasks = tasks, extra = extra)
assert pred_actions.shape == (1, 50, 32)
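
The quick test above passes `vit_layer_indices = (0, 1, 26, 27)` for a 27-layer ViT, where index 27 refers to the final embedding appended after the per-layer hiddens. A minimal sketch of that selection follows, using plain strings as stand-ins for tensors. The helper name `select_layer_context` is hypothetical and not part of the repository.

```python
def select_layer_context(hiddens, final_embed, layer_indices):
    """Pick one ViT representation per action-transformer layer."""
    # hiddens[i] is the input to ViT layer i; the final embedding is appended last,
    # so index == depth addresses the ViT output
    trajectory = list(hiddens) + [final_embed]
    return [trajectory[i] for i in layer_indices]

vit_depth = 27
hiddens = [f'hidden_{i}' for i in range(vit_depth)]  # stand-ins for tensors

selected = select_layer_context(hiddens, 'final_embed', (0, 1, 26, 27))
print(selected)
```

Because the trajectory holds `depth + 1` entries, any ordering of indices in `[0, depth]` is valid, as long as the number of indices equals the action transformer depth.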

View File

@@ -86,7 +86,7 @@ class ViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
-patch_height, patch_width = pair(patch_size)
+self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
@@ -113,7 +113,7 @@ class ViT(Module):
self.pool = pool
self.to_latent = nn.Identity()
-self.mlp_head = nn.Linear(dim, num_classes)
+self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
def forward(self, img):
batch = img.shape[0]
@@ -129,6 +129,9 @@ class ViT(Module):
x = self.transformer(x)
+if self.mlp_head is None:
+return x
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)

View File

@@ -78,7 +78,7 @@ class ViT(nn.Module):
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
-patch_height, patch_width = pair(image_patch_size)
+self.patch_size = patch_height, patch_width = pair(image_patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'

vit_pytorch/vit_detpool.py (new file, 254 lines)

@@ -0,0 +1,254 @@
from __future__ import annotations
# DetPool ViT - a vit that accepts an object mask and attends and pools only using that mask - table 1
# Dantong Niu et al. - https://openreview.net/forum?id=NZDaMcpXZm
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange, Reduce
# helpers
def exists(val):
return val is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def masked_mean(t, mask, dim = 1, eps = 1e-5):
if not exists(mask):
return t.mean(dim = dim)
mask = rearrange(mask.bool(), '... -> ... 1')
t = t.masked_fill(~mask, 0.)
return t.sum(dim = dim) / mask.sum(dim = dim).clamp(min = eps)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, mask = None):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
mask_value = -torch.finfo(dots.dtype).max
dots = dots.masked_fill(~mask, mask_value)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x, mask = mask) + x
x = ff(x) + x
return self.norm(x)
class ViTDetPool(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, use_cls_token = True, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., mask_generator: Module | None = None):
super().__init__()
image_height, image_width = pair(image_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.patch_height = patch_height
self.patch_width = patch_width
self.downsample_mask = Reduce('b (h p1) (w p2) -> b (h w)', 'max', p1 = patch_height, p2 = patch_width)
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
# maybe cls
self.use_cls_token = use_cls_token
if use_cls_token:
self.cls_token = nn.Parameter(torch.randn(dim) * 1e-2)
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim) * 1e-2)
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
self.mask_generator = mask_generator
def forward(self, img, object_mask = None):
if not exists(object_mask) and exists(self.mask_generator):
with torch.no_grad():
self.mask_generator.eval()
object_mask = self.mask_generator(img)
has_cls = self.use_cls_token
batch, _, height, width = img.shape
tokens = self.to_patch_embedding(img)
seq = tokens.shape[1]
tokens = tokens + self.pos_embedding[:seq]
if has_cls:
cls_token = repeat(self.cls_token, 'd -> b d', b = batch)
tokens, packed_shape = pack((cls_token, tokens), 'b * d')
tokens = self.dropout(tokens)
# handle the attention mask, and for final pooling
mask = None
if exists(object_mask):
assert object_mask.ndim in {3, 2}
if object_mask.shape == (batch, height, width):
mask = self.downsample_mask(object_mask)
else:
mask = object_mask
mask = rearrange(mask, 'b ... -> b (...)')
assert mask.shape == (batch, seq)
if has_cls:
mask = F.pad(mask, (1, 0), value = True)
# attend with maybe mask
tokens = self.transformer(tokens, mask = mask)
if not exists(self.mlp_head):
return tokens
# splice out cls
if has_cls:
_, tokens = unpack(tokens, packed_shape, 'b * d')
if exists(mask):
mask = mask[..., 1:]
# pooling with the mask
pooled = masked_mean(tokens, mask, dim = 1)
pooled = self.to_latent(pooled)
return self.mlp_head(pooled)
# quick test
if __name__ == '__main__':
vit = ViTDetPool(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
object_mask = torch.randint(0, 2, (1, 256, 256)).bool()
preds = vit(img, object_mask = object_mask)
assert preds.shape == (1, 1000)
preds_no_mask = vit(img)
assert preds_no_mask.shape == (1, 1000)
# test with module included
class MockMasker(Module):
def forward(self, img):
batch, _, height, width = img.shape
return torch.ones(batch, height, width).bool()
vit = ViTDetPool(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 1,
heads = 16,
mlp_dim = 2048,
mask_generator = MockMasker()
)
preds = vit(img)
assert preds.shape == (1, 1000)
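
The two mask operations in `ViTDetPool` (max-reducing a pixel mask to one flag per patch, then averaging only the kept tokens) can be illustrated without tensors. This is a pure-Python sketch under simplifying assumptions: a single image, nested lists instead of tensors, and square patches. The function bodies are illustrative and not the repository's implementation.

```python
def downsample_mask(mask, patch):
    """Max-pool a 2D boolean pixel mask into one flag per patch (row-major order)."""
    height, width = len(mask), len(mask[0])
    return [
        any(mask[y][x] for y in range(py, py + patch) for x in range(px, px + patch))
        for py in range(0, height, patch)
        for px in range(0, width, patch)
    ]

def masked_mean(tokens, mask, eps = 1e-5):
    """Average only the token vectors whose mask flag is True."""
    dim = len(tokens[0])
    kept = [t for t, m in zip(tokens, mask) if m]
    denom = max(len(kept), eps)  # mirrors the clamp that avoids division by zero
    return [sum(t[d] for t in kept) / denom for d in range(dim)]

pixel_mask = [
    [True,  False, False, False],
    [False, False, False, False],
    [False, False, False, False],
    [False, False, False, True ],
]

patch_mask = downsample_mask(pixel_mask, patch = 2)  # 4 patches of 2 x 2 pixels
tokens = [[1., 1.], [5., 5.], [7., 7.], [3., 3.]]    # one vector per patch
pooled = masked_mean(tokens, patch_mask)
```

A patch is kept as soon as a single pixel inside it belongs to the object, so the patch-level mask is a superset of the object region.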

View File

@@ -99,7 +99,7 @@ class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
-patch_height, patch_width = pair(patch_size)
+self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

View File

@@ -31,7 +31,7 @@ class FeedForward(Module):
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
@@ -40,31 +40,31 @@ class Attention(Module):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -79,7 +79,7 @@ class Transformer(Module):
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
@@ -105,73 +105,73 @@ class ViTND(Module):
emb_dropout: float = 0.
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.ndim = ndim
self.pool = pool
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b ({join(dim_names)}) ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, x):
x = self.to_patch_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x[:, 1:].mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
if __name__ == '__main__':
model = ViTND(
ndim = 4,
input_shape = (8, 16, 32, 64),
@@ -185,7 +185,7 @@ if __name__ == '__main__':
dropout = 0.1,
emb_dropout = 0.1
)
occupancy_time = torch.randn(2, 3, 8, 16, 32, 64)
logits = model(occupancy_time)

vit_pytorch/vit_nd_pope.py (new file, 353 lines)

@@ -0,0 +1,353 @@
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import pi, nn, arange, cat, stack, Tensor
from torch.nn import Module, ModuleList
from torch.amp import autocast
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
def join(arr, delimiter = ' '):
return delimiter.join(arr)
def ensure_tuple(t, length):
if isinstance(t, (tuple, list)):
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
return tuple(t)
return (t,) * length
# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/
# but using polar version instead
# Gopalakrishnan et al. https://arxiv.org/abs/2509.10534
def _phi(m: int) -> float:
x = 2.0
for _ in range(10):
x = (1 + x) ** (1.0 / (m + 1.0))
return x
def make_directions(n: int, d: int) -> Tensor:
g = _phi(d)
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
z = torch.fmod(i * alpha, 1.0)
directions = torch.erfinv(2.0 * z - 1.0)
directions = l2norm(directions)
return directions.float()
class GoldenGatePoPENd(Module):
def __init__(
self,
dim_pos: int,
heads: int,
dim_head: int,
min_freq: float = 1.0,
max_freq: float = 10000.0,
p_zero_freqs: float = 0.0, # proportion of frequencies set to 0
init_learned_bias_uniform = False
):
super().__init__()
n_freqs = dim_head
n_zero_freqs = round(p_zero_freqs * n_freqs)
omega = cat((
torch.zeros(n_zero_freqs),
min_freq * (max_freq / min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
))
directions = rearrange(
make_directions(heads * n_freqs, dim_pos),
'(h f) p -> h f p',
h = heads
)
omega_expanded = rearrange(omega, 'f -> f 1')
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
self.learned_bias = nn.Parameter(torch.zeros(heads, dim_head))
if init_learned_bias_uniform:
nn.init.uniform_(self.learned_bias, -2. * pi, 0.) # in-place init on a leaf parameter must run without grad tracking
@autocast('cuda', enabled = False)
def forward(self, pos):
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
# compute theta for each (batch, head, seq, freq)
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
bias = self.learned_bias.clamp(-2. * pi, 0.)
bias = rearrange(bias, 'h d -> h 1 d')
return theta, bias
@autocast('cuda', enabled = False)
def apply_polar_pos_emb(t, freqs):
orig_dtype = t.dtype
t = t.float()
t = F.softplus(t)
out = cat((t * freqs.cos(), t * freqs.sin()), dim = -1)
return out.type(orig_dtype)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, polar_pos_emb = None):
x = self.norm(x)
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
if exists(polar_pos_emb):
freqs, bias = polar_pos_emb
q = apply_polar_pos_emb(q, freqs)
k = apply_polar_pos_emb(k, freqs + bias)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., polar_emb = None):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.polar_emb = polar_emb
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, pos = None):
# pope embedding
polar_pos_emb = None
if exists(pos) and exists(self.polar_emb):
polar_pos_emb = self.polar_emb(pos)
# transformer layers
for attn, ff in self.layers:
x = attn(x, polar_pos_emb) + x
x = ff(x) + x
return self.norm(x)
class ViTND(Module):
def __init__(
self,
*,
ndim: int,
input_shape: int | tuple[int, ...],
patch_size: int | tuple[int, ...],
num_classes: int,
dim: int,
depth: int,
heads: int,
mlp_dim: int,
channels: int = 3,
dim_head: int = 64,
dropout: float = 0.,
emb_dropout: float = 0.,
pope_min_freq: float = 1.0,
pope_max_freq: float = 10000.0,
pope_p_zero_freqs: float = 0.0,
pope_init_learned_bias_uniform = False
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
self.ndim = ndim
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.dropout = nn.Dropout(emb_dropout)
# golden gate pope
self.polar_emb = GoldenGatePoPENd(
dim_pos = ndim,
heads = heads,
dim_head = dim_head,
min_freq = pope_min_freq,
max_freq = pope_max_freq,
p_zero_freqs = pope_p_zero_freqs,
init_learned_bias_uniform = pope_init_learned_bias_uniform
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, polar_emb = self.polar_emb)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def muon_parameters(self):
params = []
for m in self.modules():
if isinstance(m, Attention):
params.extend([
m.to_v.weight,
m.to_out[0].weight
])
elif isinstance(m, FeedForward):
params.extend([
m.net[1].weight,
m.net[-2].weight
])
return params
def forward(
self,
x,
return_embed = False
):
x = self.to_patch_embedding(x) # (b, *spatial_dims, dim)
batch, *spatial_dims, _, device = *x.shape, x.device
# Generate position coordinates
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
grid = torch.meshgrid(*grids, indexing = 'ij')
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
# flatten spatial dimensions for attention with nd rotary
pos = repeat(pos, '... p -> b (...) p', b = batch)
x, packed_shape = pack([x], 'b * d')
x = self.dropout(x)
embed = self.transformer(x, pos)
# return the embed with reconstituted patch shape
if return_embed:
embed, = unpack(embed, packed_shape, 'b * d')
return embed
# pooling to logits
pooled = reduce(embed, 'b n d -> b d', 'mean')
pooled = self.to_latent(pooled)
return self.mlp_head(pooled)
if __name__ == '__main__':
model = ViTND(
ndim = 5,
input_shape = (4, 8, 16, 32, 64),
patch_size = (2, 2, 4, 4, 8),
num_classes = 1000,
dim = 512,
depth = 6,
heads = 8,
mlp_dim = 2048,
channels = 3,
dropout = 0.1,
emb_dropout = 0.1
)
data = torch.randn(3, 3, 4, 8, 16, 32, 64)
logits = model(data)
embed = model(data, return_embed = True)
assert embed.shape == (3, 2, 4, 4, 8, 8, 512)
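
As a rough illustration of what `apply_polar_pos_emb` does, the sketch below works on a single head with 1D positions in pure Python. Each feature becomes a magnitude (softplus of the activation) and a phase (frequency times position, plus a bias on the key side), so the query-key score depends only on the position difference. The frequencies, bias values, and helper names here are assumptions chosen for the example.

```python
import math

def softplus(x):
    return math.log1p(math.exp(x))

def polar_embed(t, theta):
    """Map features to (magnitude * cos(phase), magnitude * sin(phase)), concatenated."""
    mags = [softplus(v) for v in t]
    cos_part = [m * math.cos(a) for m, a in zip(mags, theta)]
    sin_part = [m * math.sin(a) for m, a in zip(mags, theta)]
    return cos_part + sin_part

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

freqs = [1.0, 0.1]   # one frequency per feature (hypothetical values)
bias = [-0.5, -1.0]  # phase offset applied to keys only, within [-2 * pi, 0]
q, k = [0.3, -1.2], [0.8, 0.5]

def score(pos_q, pos_k):
    theta_q = [f * pos_q for f in freqs]
    theta_k = [f * pos_k + b for f, b in zip(freqs, bias)]
    return dot(polar_embed(q, theta_q), polar_embed(k, theta_k))

near = score(2., 5.)
shifted = score(12., 15.)  # both positions moved by the same offset
```

Since `cos(a) * cos(b) + sin(a) * sin(b) = cos(a - b)`, each feature contributes `softplus(q) * softplus(k) * cos(freq * (pos_q - pos_k) - bias)`, which is why `near` and `shifted` agree up to floating point error.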

View File

@@ -75,23 +75,23 @@ class GoldenGateRoPENd(Module):
# input shape: (b, h, n, d) where d = head_dim
# pos shape: (b, n, p) where p = pos_dim
# self.freqs shape: (h, f, p) where f = d // 2
x, y = input.float().chunk(2, dim = -1) # both (b, h, n, f)
# Expand dimensions for broadcasting
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
# Compute theta for each (batch, head, seq, freq)
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)
# Apply rotation
x_out = x * cos_theta - y * sin_theta
y_out = x * sin_theta + y * cos_theta
output = cat((x_out, y_out), dim=-1)
return output.type_as(input)
@@ -108,7 +108,7 @@ class FeedForward(Module):
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
@@ -117,15 +117,15 @@ class Attention(Module):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.rotary_emb = rotary_emb
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
@@ -133,24 +133,24 @@ class Attention(Module):
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, pos = None):
x = self.norm(x)
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
# Apply rotary embeddings if available
if exists(self.rotary_emb):
assert exists(pos)
q = self.rotary_emb(q, pos)
k = self.rotary_emb(k, pos)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -165,7 +165,7 @@ class Transformer(Module):
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, rotary_emb = rotary_emb),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, pos = None):
for attn, ff in self.layers:
x = attn(x, pos) + x
@@ -193,45 +193,45 @@ class ViTND(Module):
rope_p_zero_freqs: float = 0.0
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
self.ndim = ndim
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.dropout = nn.Dropout(emb_dropout)
# Create rotary embeddings
self.rotary_emb = GoldenGateRoPENd(
dim_pos = ndim,
@@ -241,12 +241,12 @@ class ViTND(Module):
rope_max_freq = rope_max_freq,
rope_p_zero_freqs = rope_p_zero_freqs
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, rotary_emb = self.rotary_emb)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def muon_parameters(self):
params = []
@@ -270,9 +270,9 @@ class ViTND(Module):
return_embed = False
):
x = self.to_patch_embedding(x) # (b, *spatial_dims, dim)
batch, *spatial_dims, _, device = *x.shape, x.device
# Generate position coordinates
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
@@ -280,12 +280,12 @@ class ViTND(Module):
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
# flatten spatial dimensions for attention with nd rotary
pos = repeat(pos, '... p -> b (...) p', b = batch)
x, packed_shape = pack([x], 'b * d')
x = self.dropout(x)
embed = self.transformer(x, pos)
# return the embed with reconstituted patch shape
@@ -303,7 +303,7 @@ class ViTND(Module):
if __name__ == '__main__':
model = ViTND(
ndim = 5,
input_shape = (4, 8, 16, 32, 64),


@@ -155,7 +155,7 @@ class ViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., decorr_sample_frac = 1.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
@@ -231,4 +231,4 @@ if __name__ == '__main__':
decorr_loss = DecorrelationLoss(0.0001, soft_validate_num_sampled = True)
out = decorr_loss(hiddens)
assert out.item() == 0
assert out.item() == 0


@@ -0,0 +1,217 @@
from __future__ import annotations
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# functions
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim, bias = False),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.norm = nn.LayerNorm(dim, bias = False)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(
self,
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout = 0.,
keel_residual_scale = None
):
super().__init__()
assert depth > 1
self.layers = ModuleList([])
for _ in range(depth):
self.layers.extend([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
])
num_layers = depth * 2
self.keel_residual_scale = default(keel_residual_scale, num_layers)
self.post_norms = ModuleList([nn.LayerNorm(dim, bias = False) for _ in range(num_layers - 1)])
def forward(self, x):
residual_scale = self.keel_residual_scale
for layer_ind, layer in enumerate(self.layers):
first_layer = layer_ind == 0
residual = x
out = layer(x)
if first_layer:
x = out + residual
continue
post_norm = self.post_norms[layer_ind - 1]
x = post_norm(out + residual * residual_scale)
return x
class ViT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
pool = 'cls',
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
keel_residual_scale = None
):
super().__init__()
image_height, image_width = pair(image_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_cls_tokens = 1 if pool == 'cls' else 0
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout,
keel_residual_scale = keel_residual_scale
)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
def forward(self, img):
batch = img.shape[0]
x = self.to_patch_embedding(img)
cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
x = torch.cat((cls_tokens, x), dim = 1)
seq = x.shape[1]
x = x + self.pos_embedding[:seq]
x = self.dropout(x)
x = self.transformer(x)
if not exists(self.mlp_head):
return x
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
if __name__ == '__main__':
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img)
assert preds.shape == (1, 1000)
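
The residual rule in the `Transformer` above (plain residual on the first sublayer, then `post_norm(out + residual * residual_scale)` on every later one) can be sketched in plain Python on a single feature vector. This is a minimal illustration only: `layer_norm`, `scaled_post_norm_stack` and the toy sublayers are hypothetical names, the norm has no learned gain, and each later sublayer would have its own norm in the module above.

```python
import math

def layer_norm(vec, eps = 1e-5):
    # normalize one feature vector to zero mean and unit variance (no learned gain)
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]

def scaled_post_norm_stack(x, layers, residual_scale = None):
    # default scale is the number of sublayers, matching default(keel_residual_scale, num_layers)
    scale = len(layers) if residual_scale is None else residual_scale

    for layer_ind, layer in enumerate(layers):
        residual = x
        out = layer(x)

        if layer_ind == 0:
            # first sublayer keeps a plain residual and skips the post norm
            x = [o + r for o, r in zip(out, residual)]
            continue

        # later sublayers scale up the skip path, then normalize the sum
        x = layer_norm([o + r * scale for o, r in zip(out, residual)])

    return x

# hypothetical toy sublayers standing in for attention / feedforward
toy_layers = [lambda v: [0.5 * t for t in v] for _ in range(4)]
result = scaled_post_norm_stack([1.0, 2.0, 3.0, 4.0], toy_layers)
```

With the default scale the skip path dominates the sum before normalization, which is why the first sublayer, having no post norm, is handled separately.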


@@ -97,7 +97,7 @@ class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., patch_dropout = 0.25):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'


@@ -108,7 +108,7 @@ class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
self.patch_size = patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'


@@ -1,5 +1,10 @@
from collections import namedtuple
import torch
from torch import nn
from torch import nn, cat
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.nn.attention import SDPBackend, sdpa_kernel
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
@@ -9,12 +14,15 @@ from einops.layers.torch import Rearrange
def exists(val):
return val is not None
def divisible_by(num, den):
return (num % den) == 0
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class FeedForward(nn.Module):
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
@@ -28,9 +36,11 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_flash_attn = True):
super().__init__()
self.use_flash_attn = use_flash_attn
self.dropout_p = dropout
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
@@ -48,61 +58,100 @@ class Attention(nn.Module):
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
def flash_attn(self, q, k, v, mask = None):
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
dropout_p = self.dropout_p,
is_causal = False,
scale = self.scale
)
return out
def forward(self, x, mask = None):
batch, seq, _ = x.shape
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
attn = self.attend(dots)
attn = self.dropout(attn)
if self.use_flash_attn:
out = self.flash_attn(q, k, v, mask = mask)
else:
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(mask):
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
super().__init__()
self.use_flash_attn = use_flash_attn
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x) + x
x = attn(x, mask = mask) + x
x = ff(x) + x
return self.norm(x)
class FactorizedTransformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
class FactorizedTransformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
super().__init__()
self.use_flash_attn = use_flash_attn
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
b, f, n, _ = x.shape
def forward(self, x, mask = None):
batch, frames, seq, _ = x.shape
if exists(mask):
mask = repeat(mask, 'b ... -> (b space) ...', space = x.shape[2])
for spatial_attn, temporal_attn, ff in self.layers:
x = rearrange(x, 'b f n d -> (b f) n d')
x = spatial_attn(x) + x
x = rearrange(x, '(b f) n d -> (b n) f d', b=b, f=f)
x = temporal_attn(x) + x
x = rearrange(x, '(b f) n d -> (b n) f d', b = batch, f = frames)
x = temporal_attn(x, mask = mask) + x
x = ff(x) + x
x = rearrange(x, '(b n) f d -> b f n d', b=b, n=n)
x = rearrange(x, '(b n) f d -> b f n d', b = batch, n = seq)
return self.norm(x)
class ViT(nn.Module):
class ViViT(Module):
def __init__(
self,
*,
@@ -122,13 +171,14 @@ class ViT(nn.Module):
dropout = 0.,
emb_dropout = 0.,
variant = 'factorized_encoder',
use_flash_attn: bool = True,
):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(image_patch_size)
self.patch_size = patch_height, patch_width = pair(image_patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
assert divisible_by(frames, frame_patch_size), 'Frames must be divisible by frame patch size'
assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
@@ -138,6 +188,8 @@ class ViT(nn.Module):
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.frame_patch_size = frame_patch_size
self.global_average_pool = pool == 'mean'
self.to_patch_embedding = nn.Sequential(
@@ -154,11 +206,11 @@ class ViT(nn.Module):
if variant == 'factorized_encoder':
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
elif variant == 'factorized_self_attention':
assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
self.pool = pool
self.to_latent = nn.Identity()
@@ -166,25 +218,36 @@ class ViT(nn.Module):
self.mlp_head = nn.Linear(dim, num_classes)
self.variant = variant
def forward(self, video):
x = self.to_patch_embedding(video)
b, f, n, _ = x.shape
def forward(self, video, mask = None):
device = video.device
x = x + self.pos_embedding[:, :f, :n]
x = self.to_patch_embedding(video)
batch, frames, seq, _ = x.shape
x = x + self.pos_embedding[:, :frames, :seq]
if exists(self.spatial_cls_token):
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
x = torch.cat((spatial_cls_tokens, x), dim = 2)
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = batch, f = frames)
x = cat((spatial_cls_tokens, x), dim = 2)
x = self.dropout(x)
# maybe temporal mask
temporal_mask = None
if exists(mask):
temporal_mask = reduce(mask, 'b (f patch) -> b f', 'all', patch = self.frame_patch_size)
# the two variants
if self.variant == 'factorized_encoder':
x = rearrange(x, 'b f n d -> (b f) n d')
# attend across space
x = self.spatial_transformer(x)
x = rearrange(x, '(b f) n d -> b f n d', b = b)
x = rearrange(x, '(b f) n d -> b f n d', b = batch)
# excise out the spatial cls tokens or average pool for temporal attention
@@ -193,22 +256,50 @@ class ViT(nn.Module):
# append temporal CLS tokens
if exists(self.temporal_cls_token):
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d -> b 1 d', b = batch)
x = torch.cat((temporal_cls_tokens, x), dim = 1)
x = cat((temporal_cls_tokens, x), dim = 1)
if exists(temporal_mask):
temporal_mask = F.pad(temporal_mask, (1, 0), value = True)
# attend across time
x = self.temporal_transformer(x)
x = self.temporal_transformer(x, mask = temporal_mask)
# excise out temporal cls token or average pool
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
elif self.variant == 'factorized_self_attention':
x = self.factorized_transformer(x)
x = self.factorized_transformer(x, mask = temporal_mask)
x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')
x = self.to_latent(x)
return self.mlp_head(x)
# main
if __name__ == '__main__':
vivit = ViViT(
dim = 512,
spatial_depth = 2,
temporal_depth = 2,
heads = 4,
mlp_dim = 2048,
image_size = 256,
image_patch_size = 16,
frames = 8,
frame_patch_size = 2,
num_classes = 1000,
variant = 'factorized_encoder',
)
video = torch.randn(3, 3, 8, 256, 256)
mask = torch.randint(0, 2, (3, 8)).bool()
logits = vivit(video, mask = None)
assert logits.shape == (3, 1000)
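
How the per-frame `mask` becomes the temporal attention mask in `ViViT.forward` can be sketched without torch for a single sample. This is a rough illustration, assuming `True` means "keep"; `to_temporal_mask` is a hypothetical helper, not part of the file above. It mirrors the `'all'` reduction over each frame patch and the `F.pad(..., (1, 0), value = True)` that makes room for the temporal cls token.

```python
def to_temporal_mask(frame_mask, frame_patch_size, has_cls_token = True):
    # frame_mask: list of bools, one per raw frame, True = keep
    assert len(frame_mask) % frame_patch_size == 0, 'frames must be divisible by frame patch size'

    # a frame patch is kept only if all of its frames are kept ('all' reduction)
    patches = [
        all(frame_mask[i:i + frame_patch_size])
        for i in range(0, len(frame_mask), frame_patch_size)
    ]

    # the temporal cls token is prepended to the sequence and is always attendable
    return ([True] + patches) if has_cls_token else patches

frame_mask = [True, True, True, False, False, False, True, True]
temporal_mask = to_temporal_mask(frame_mask, frame_patch_size = 2)
# temporal_mask == [True, True, False, False, True]
```

Because the cls position is always `True`, every query row keeps at least one attendable key even when all frame patches of a sample are masked out.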


@@ -0,0 +1,386 @@
# https://openreview.net/forum?id=Co6SCyBIjo
# applied in https://arxiv.org/abs/2605.03269 - 50-85% jump on pick-and-place from a moving conveyor belt
import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from einops import rearrange, repeat, reduce, einsum
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def l2norm(t):
return F.normalize(t, dim = -1)
# normalization helpers
class ChanLayerNorm(Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.gamma
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_flash_attn = True):
super().__init__()
self.use_flash_attn = use_flash_attn
self.dropout_p = dropout
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def flash_attn(self, q, k, v, mask = None):
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
dropout_p = self.dropout_p,
is_causal = False,
scale = self.scale
)
return out
def forward(self, x, mask = None):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
if self.use_flash_attn:
out = self.flash_attn(q, k, v, mask = mask)
else:
dots = einsum(q, k, 'b h i d, b h j d -> b h i j') * self.scale
if exists(mask):
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x, mask = mask) + x
x = ff(x) + x
return self.norm(x)
# moss specific classes
class STSSEncoder(Module):
def __init__(self, dim, local_time = 3, local_height = 3, local_width = 3, hidden_dim = 64):
super().__init__()
self.spatial_to_hidden = nn.Linear(local_height * local_width, hidden_dim)
self.conv = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, 3, padding = 1),
ChanLayerNorm(hidden_dim),
nn.GELU(),
nn.Conv2d(hidden_dim, hidden_dim, 3, padding = 1),
ChanLayerNorm(hidden_dim),
nn.GELU()
)
self.time_to_out = nn.Linear(local_time * hidden_dim, dim)
def forward(self, sim):
b, t, h, w, lt, lh, lw = sim.shape
x = rearrange(sim, 'b t h w lt lh lw -> b t h w lt (lh lw)')
x = self.spatial_to_hidden(x)
x = rearrange(x, 'b t h w lt d -> (b t lt) d h w')
x = self.conv(x)
x = rearrange(x, '(b t lt) d h w -> b t h w (lt d)', b = b, t = t, lt = lt)
return self.time_to_out(x)
class MOSS(Module):
def __init__(
self,
dim,
local_time = 3,
local_height = 3,
local_width = 3,
hidden_dim = 64,
orders = 2,
causal = False
):
super().__init__()
assert is_odd(local_time) and is_odd(local_height) and is_odd(local_width), 'MOSS local dimensions must be odd'
self.local_time = local_time
self.local_height = local_height
self.local_width = local_width
self.causal = causal
self.encoders = ModuleList([STSSEncoder(dim, local_time, local_height, local_width, hidden_dim) for _ in range(orders)])
self.to_order_out = ModuleList([nn.Linear(dim, dim) for _ in range(orders)])
self.to_out = nn.Linear(dim, dim)
def stss_transform(self, x):
lt, lh, lw = self.local_time, self.local_height, self.local_width
x = l2norm(x)
x = rearrange(x, 'b t h w c -> b c t h w')
pad_h, pad_w = lh // 2, lw // 2
pad_t_past, pad_t_future = (lt - 1, 0) if self.causal else (lt // 2, lt // 2)
padded_x = F.pad(x, (pad_w, pad_w, pad_h, pad_h, pad_t_past, pad_t_future))
windows = padded_x.unfold(2, lt, 1).unfold(3, lh, 1).unfold(4, lw, 1)
return einsum(x, windows, 'b c t h w, b c t h w l u v -> b t h w l u v')
def forward(self, x):
out = self.to_out(x)
for encoder, to_order_out in zip(self.encoders, self.to_order_out):
sim = self.stss_transform(x)
x = encoder(sim)
out = out + to_order_out(x)
return out
# main architecture
class ViViT(Module):
def __init__(
self,
*,
image_size,
image_patch_size,
frames,
frame_patch_size,
num_classes,
dim,
spatial_depth,
temporal_depth,
heads,
mlp_dim,
pool = 'cls',
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
use_flash_attn: bool = True,
moss_local_time = 3,
moss_local_height = 3,
moss_local_width = 3,
moss_hidden_dim = 64,
moss_orders = 2,
moss_causal = True,
):
super().__init__()
image_height, image_width = pair(image_size)
self.patch_size = patch_height, patch_width = pair(image_patch_size)
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
assert divisible_by(frames, frame_patch_size), 'Frames must be divisible by frame patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
num_frame_patches = frames // frame_patch_size
patch_dim = channels * patch_height * patch_width * frame_patch_size
self.frame_patch_size = frame_patch_size
self.patch_h = image_height // patch_height
self.patch_w = image_width // patch_width
self.global_average_pool = pool == 'mean'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
self.dropout = nn.Dropout(emb_dropout)
self.has_cls = not self.global_average_pool
self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if self.has_cls else None
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if self.has_cls else None
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
self.moss = MOSS(
dim,
local_time = moss_local_time,
local_height = moss_local_height,
local_width = moss_local_width,
hidden_dim = moss_hidden_dim,
orders = moss_orders,
causal = moss_causal
)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, video, mask = None):
x = self.to_patch_embedding(video)
batch, frames, seq, _ = x.shape
x = x + self.pos_embedding[:, :frames, :seq]
if self.has_cls:
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = batch, f = frames)
x = torch.cat((spatial_cls_tokens, x), dim = 2)
x = self.dropout(x)
# temporal mask
temporal_mask = None
if exists(mask):
temporal_mask = reduce(mask, 'b (f patch) -> b f', 'all', patch = self.frame_patch_size)
x = rearrange(x, 'b f n d -> (b f) n d')
# attend across space
x = self.spatial_transformer(x)
x = rearrange(x, '(b f) n d -> b f n d', b = batch)
# moss integration over spatial patch tokens
if self.has_cls:
spatial_cls_tokens, patch_tokens = x[:, :, :1], x[:, :, 1:]
else:
patch_tokens = x
patch_tokens = rearrange(patch_tokens, 'b f (h w) d -> b f h w d', h = self.patch_h, w = self.patch_w)
patch_tokens = self.moss(patch_tokens)
patch_tokens = rearrange(patch_tokens, 'b f h w d -> b f (h w) d')
# pool spatial features
moss_pooled = reduce(patch_tokens, 'b f n d -> b f d', 'mean')
if self.has_cls:
x = rearrange(spatial_cls_tokens, 'b f 1 d -> b f d') + moss_pooled
else:
x = moss_pooled
# append temporal cls tokens
if self.has_cls:
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d -> b 1 d', b = batch)
x = torch.cat((temporal_cls_tokens, x), dim = 1)
if exists(temporal_mask):
temporal_mask = F.pad(temporal_mask, (1, 0), value = True)
# attend across time
x = self.temporal_transformer(x, mask = temporal_mask)
# temporal pooling
x = x[:, 0] if self.has_cls else reduce(x, 'b f d -> b d', 'mean')
return self.mlp_head(x)
if __name__ == '__main__':
vivit = ViViT(
dim = 512,
spatial_depth = 2,
temporal_depth = 2,
heads = 4,
mlp_dim = 2048,
image_size = 256,
image_patch_size = 32,
frames = 8,
frame_patch_size = 2,
num_classes = 1000,
)
video = torch.randn(2, 3, 8, 256, 256)
mask = torch.randint(0, 2, (2, 8)).bool()
logits = vivit(video, mask = None)
assert logits.shape == (2, 1000)
logits = vivit(video, mask = mask)
assert logits.shape == (2, 1000)
moss = MOSS(
dim = 512,
local_time = 3,
local_height = 3,
local_width = 3,
hidden_dim = 64,
orders = 2,
causal = True
)
moss_input = torch.randn(2, 8, 16, 16, 512) # (batch, frames, height, width, dim)
moss_output = moss(moss_input)
assert moss_output.shape == (2, 8, 16, 16, 512)
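
The effect of the `causal` flag on the temporal neighbourhood in `MOSS.stss_transform` can be illustrated with frame indices alone. This is a plain-Python sketch, not the tensor code above: `temporal_window_indices` is a hypothetical helper, and `None` stands for a zero-padded position (which contributes zero similarity, since the padded features are all zeros).

```python
def temporal_window_indices(num_frames, local_time = 3, causal = True):
    # mirrors the padding choice in stss_transform: all padding goes to the past when causal
    assert local_time % 2 == 1, 'local time must be odd'
    pad_past, pad_future = (local_time - 1, 0) if causal else (local_time // 2, local_time // 2)

    # None marks a zero padded position, integers are real frame indices
    padded = [None] * pad_past + list(range(num_frames)) + [None] * pad_future

    # sliding window of size local_time with stride 1, giving one window per frame
    return [padded[t:t + local_time] for t in range(len(padded) - local_time + 1)]

causal_windows = temporal_window_indices(4, causal = True)
# [[None, None, 0], [None, 0, 1], [0, 1, 2], [1, 2, 3]]

centered_windows = temporal_window_indices(4, causal = False)
# [[None, 0, 1], [0, 1, 2], [1, 2, 3], [2, 3, None]]
```

In both modes the number of windows equals the number of frames, so the similarity tensor keeps the input's time dimension. In causal mode the last entry of each window is the frame itself, and no window reaches into the future.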