Compare commits


14 Commits
1.17.6 ... main

Author SHA1 Message Date
lucidrains
93df0e6046 add another vit variant, where they found improvements for certain tasks when the cls token gets its own specialized parameters - layernorms were crucial 2026-05-01 13:49:43 -07:00
lucidrains
8e104e9afc cleanup vit det pool 2026-05-01 12:40:31 -07:00
lucidrains
3f03aa3994 add a vit that can accept an object mask (from sam or other seg models), and only attends and pools those patch tokens 2026-05-01 12:17:19 -07:00
lucidrains
2da1b45b9b allow vit to modulate the parallel and orthog components 2026-04-18 09:46:29 -07:00
lucidrains
7ab07c2499 add vit with orthogonal residual update 2026-04-18 09:24:52 -07:00
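For context, a sketch of the orthogonal residual update referenced in the two commits above, read from the commit messages and the title of the cited paper "Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks" (the exact parameterization in the repo may differ):

```latex
f_{\parallel}(x) \;=\; \frac{\langle f(x),\, x \rangle}{\lVert x \rVert^{2}}\, x,
\qquad
f_{\perp}(x) \;=\; f(x) - f_{\parallel}(x)

x \;\leftarrow\; x + \alpha\, f_{\perp}(x) + \beta\, f_{\parallel}(x)
```

A plain residual corresponds to alpha = beta = 1, the orthogonal update keeps only the orthogonal component (beta = 0), and the "modulate the parallel and orthog components" commit presumably makes alpha and beta learnable.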
lucidrains
dea6b0da56 blur the line between depth and recurrence even more 2026-04-06 09:27:02 -07:00
lucidrains
13284b7af1 first attention residual should be disabled, cleanup 2026-04-05 20:14:06 -07:00
lucidrains
7e18d0302e stop relying on github 2026-03-27 08:07:42 -07:00
lucidrains
b80676e09c add an attention residual example (kimi team) as well as dino / byol redone with sigreg (lejepa) 2026-03-27 07:51:12 -07:00
lucidrains
fc1e727428 add ability to condition on binned advantages for the vision action transformers 2026-02-19 14:10:44 -08:00
lucidrains
6032a54b48 patch 2026-02-11 11:49:51 -08:00
Harikrishna KP
06a1f42924 Fix ViViT Transformer not passing use_flash_attn to Attention and duplicate mask reshape (#360)
Two related bugs in vivit.py:

1. Transformer.__init__ accepted use_flash_attn but never forwarded it to the
   Attention modules it creates. Since Attention defaults to use_flash_attn=True,
   setting use_flash_attn=False on ViViT had no effect on the factorized_encoder
   variant's spatial and temporal transformers.

2. Attention.forward reshaped the mask from 2D to 4D before the flash/non-flash
   branch (line 82), then attempted to reshape it again inside the non-flash
   branch (line 92). When the non-flash code path is actually reached with a
   mask, einops raises an error because the mask is already 4D.

   These bugs masked each other: bug #1 prevented bug #2 from triggering because
   the non-flash path was never taken even when requested.

Fix: pass use_flash_attn through to Attention in Transformer.__init__, and
remove the redundant second mask rearrange in the non-flash branch.
2026-02-11 11:49:31 -08:00
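The interaction of the two bugs described above can be sketched with a minimal mock. `Attention` and `Transformer` here are simplified stand-ins for the classes in vivit.py, not the actual module code:

```python
# Minimal mock of the two coupled ViViT bugs and their fix.

class Attention:
    def __init__(self, use_flash_attn = True):
        # Attention defaults to use_flash_attn = True, which is why a
        # non-forwarded flag silently kept every layer on the flash path
        self.use_flash_attn = use_flash_attn

    def forward(self, mask = None):
        if mask is not None:
            # single up-front reshape (stand-in for rearrange(mask, 'b j -> b 1 1 j'))
            mask = ('b 1 1 j', mask)
        if self.use_flash_attn:
            return 'flash'
        # the buggy version reshaped `mask` a second time here, which failed
        # once the mask was already 4D; the fixed version uses it as-is
        return 'non-flash'

class Transformer:
    def __init__(self, depth, use_flash_attn = True):
        # the fix for bug #1: forward the flag instead of relying on
        # Attention's default
        self.layers = [Attention(use_flash_attn = use_flash_attn) for _ in range(depth)]

    def forward(self, mask = None):
        return [attn.forward(mask) for attn in self.layers]

# with the flag forwarded, use_flash_attn = False now actually reaches the
# non-flash branch - which is why bug #2 only surfaced once bug #1 was fixed
t = Transformer(depth = 2, use_flash_attn = False)
print(t.forward(mask = 'b j'))  # ['non-flash', 'non-flash']
```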
Phil Wang
6ae6a3ab64 cleanup 2026-02-04 13:29:40 -08:00
lucidrains
827300beed add vit with keel post ln, proposed by bytedance for scaling depth 2026-02-04 09:09:17 -08:00
45 changed files with 2203 additions and 654 deletions

3
.github/FUNDING.yml vendored
View File

@@ -1,3 +0,0 @@
# These are supported funding model platforms
github: [lucidrains]

View File

@@ -1,36 +0,0 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
name: Upload Python Package
on:
release:
types: [published]
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

View File

@@ -1,34 +0,0 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Test
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8, 3.9]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
python -m pip install -e .
python -m pip install pytest
- name: Test with pytest
run: |
pytest -q

3
.gitignore vendored
View File

@@ -127,3 +127,6 @@ dmypy.json
# Pyre type checker
.pyre/
# scripts
*.sh

148
README.md
View File

@@ -90,26 +90,26 @@ preds = v(img) # (1, 1000)
## Parameters
- `image_size`: int.
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
- `patch_size`: int.
Size of patches. `image_size` must be divisible by `patch_size`.
The number of patches is: ` n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
- `num_classes`: int.
Number of classes to classify.
- `dim`: int.
Last dimension of output tensor after linear transformation `nn.Linear(..., dim)`.
- `depth`: int.
Number of Transformer blocks.
- `heads`: int.
Number of heads in Multi-head Attention layer.
- `mlp_dim`: int.
Dimension of the MLP (FeedForward) layer.
- `channels`: int, default `3`.
Number of image channels.
- `dropout`: float between `[0, 1]`, default `0.`.
Dropout rate.
- `emb_dropout`: float between `[0, 1]`, default `0`.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling
@@ -972,7 +972,7 @@ torch.save(model.state_dict(), './pretrained-net.pt')
<img src="./images/mp3.png" width="400px"></img>
New <a href="https://arxiv.org/abs/2207.07611">paper</a> that introduces masked position prediction pre-training criteria. This strategy is more efficient than the Masked Autoencoder strategy and has comparable performance.
```python
import torch
@@ -1361,7 +1361,7 @@ learner = Dino(
num_classes_K = 65536, # output logits dimensions (referenced as K in paper)
student_temp = 0.9, # student temperature
teacher_temp = 0.04, # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper
global_lower_crop_scale = 0.5, # lower bound for global crop - 0.5 was recommended in the paper
moving_average_decay = 0.9, # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
@@ -1735,7 +1735,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{touvron2020training,
title = {Training data-efficient image transformers & distillation through attention},
author = {Hugo Touvron and Matthieu Cord and Matthijs Douze and Francisco Massa and Alexandre Sablayrolles and Hervé Jégou},
year = {2020},
eprint = {2012.12877},
@@ -1768,7 +1768,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{touvron2021going,
title = {Going deeper with Image Transformers},
author = {Hugo Touvron and Matthieu Cord and Alexandre Sablayrolles and Gabriel Synnaeve and Hervé Jégou},
year = {2021},
eprint = {2103.17239},
@@ -1801,7 +1801,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{heo2021rethinking,
title = {Rethinking Spatial Dimensions of Vision Transformers},
author = {Byeongho Heo and Sangdoo Yun and Dongyoon Han and Sanghyuk Chun and Junsuk Choe and Seong Joon Oh},
year = {2021},
eprint = {2103.16302},
@@ -1845,7 +1845,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
@@ -1867,7 +1867,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{chen2021regionvit,
title = {RegionViT: Regional-to-Local Attention for Vision Transformers},
author = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan},
year = {2021},
eprint = {2106.02689},
@@ -1878,7 +1878,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{wang2021crossformer,
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
year = {2021},
eprint = {2108.00154},
@@ -1900,7 +1900,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{he2021masked,
title = {Masked Autoencoders Are Scalable Vision Learners},
author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
year = {2021},
eprint = {2111.06377},
@@ -1911,7 +1911,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{xie2021simmim,
title = {SimMIM: A Simple Framework for Masked Image Modeling},
author = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
year = {2021},
eprint = {2111.09886},
@@ -1944,7 +1944,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{lee2021vision,
title = {Vision Transformer for Small-Size Datasets},
author = {Seung Hoon Lee and Seunghyun Lee and Byung Cheol Song},
year = {2021},
eprint = {2112.13492},
@@ -1966,7 +1966,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{yang2022scalablevit,
title = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer},
author = {Rui Yang and Hailong Ma and Jie Wu and Yansong Tang and Xuefeng Xiao and Min Zheng and Xiu Li},
year = {2022},
eprint = {2203.10790},
@@ -2203,37 +2203,119 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{carrigg2025decorrelationspeedsvisiontransformers,
title = {Decorrelation Speeds Up Vision Transformers},
author = {Kieran Carrigg and Rob van Gastel and Melda Yeghaian and Sander Dalm and Faysal Boughorbel and Marcel van Gerven},
year = {2025},
eprint = {2510.14657},
archivePrefix = {arXiv},
primaryClass = {cs.CV},
url = {https://arxiv.org/abs/2510.14657},
}
```
```bibtex
@misc{gopalakrishnan2025decouplingwhatwherepolar,
title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
year = {2025},
eprint = {2509.10534},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2509.10534},
}
```
```bibtex
@misc{qiu2025gatedattentionlargelanguage,
title = {Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free},
author = {Zihan Qiu and Zekun Wang and Bo Zheng and Zeyu Huang and Kaiyue Wen and Songlin Yang and Rui Men and Le Yu and Fei Huang and Suozhi Huang and Dayiheng Liu and Jingren Zhou and Junyang Lin},
year = {2025},
eprint = {2505.06708},
archivePrefix = {arXiv},
primaryClass = {cs.CL},
url = {https://arxiv.org/abs/2505.06708}
}
```
```bibtex
@misc{chen2026postlayernormbackstableexpressive,
title = {Post-LayerNorm Is Back: Stable, ExpressivE, and Deep},
author = {Chen Chen and Lai Wei},
year = {2026},
eprint = {2601.19895},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2601.19895},
}
```
```bibtex
@misc{intelligence2025pi06vlalearnsexperience,
title = {$\pi^{*}_{0.6}$: a VLA That Learns From Experience},
author = {Physical Intelligence and Ali Amin and Raichelle Aniceto and Ashwin Balakrishna and Kevin Black and Ken Conley and Grace Connors and James Darpinian and Karan Dhabalia and Jared DiCarlo and Danny Driess and Michael Equi and Adnan Esmail and Yunhao Fang and Chelsea Finn and Catherine Glossop and Thomas Godden and Ivan Goryachev and Lachy Groom and Hunter Hancock and Karol Hausman and Gashon Hussein and Brian Ichter and Szymon Jakubczak and Rowan Jen and Tim Jones and Ben Katz and Liyiming Ke and Chandra Kuchi and Marinda Lamb and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Yao Lu and Vishnu Mano and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Charvi Sharma and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and Will Stoeckle and Alex Swerdlow and James Tanner and Marcel Torne and Quan Vuong and Anna Walling and Haohuan Wang and Blake Williams and Sukwon Yoo and Lili Yu and Ury Zhilinsky and Zhiyuan Zhou},
year = {2025},
eprint = {2511.14759},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2511.14759},
}
```
```bibtex
@misc{kimiteam2026attentionresiduals,
title = {Attention Residuals},
author = {Kimi Team and Guangyu Chen and Yu Zhang and Jianlin Su and Weixin Xu and Siyuan Pan and Yaoyu Wang and Yucheng Wang and Guanduo Chen and Bohong Yin and Yutian Chen and Junjie Yan and Ming Wei and Y. Zhang and Fanqing Meng and Chao Hong and Xiaotong Xie and Shaowei Liu and Enzhe Lu and Yunpeng Tai and Yanru Chen and Xin Men and Haiqing Guo and Y. Charles and Haoyu Lu and Lin Sui and Jinguo Zhu and Zaida Zhou and Weiran He and Weixiao Huang and Xinran Xu and Yuzhi Wang and Guokun Lai and Yulun Du and Yuxin Wu and Zhilin Yang and Xinyu Zhou},
year = {2026},
eprint = {2603.15031},
archivePrefix = {arXiv},
primaryClass = {cs.CL},
url = {https://arxiv.org/abs/2603.15031},
}
```
```bibtex
@misc{balestriero2025lejepa,
title = {LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics},
author = {Randall Balestriero and Yann LeCun},
year = {2025},
eprint = {2511.08544},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2511.08544},
}
```
```bibtex
@misc{oh2026revisitingresidualconnectionsorthogonal,
title = {Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks},
author = {Giyeong Oh and Woohyun Cho and Siyeol Kim and Suhwan Choi and Youngjae Yu},
year = {2026},
eprint = {2505.11881},
archivePrefix = {arXiv},
primaryClass = {cs.CV},
url = {https://arxiv.org/abs/2505.11881},
}
```
```bibtex
@inproceedings{niu2026learning,
title = {Learning to Grasp Anything By Playing with Random Toys},
author = {Dantong Niu and Yuvan Sharma and Baifeng Shi and Rachel Ding and Matteo Gioia and Haoru Xue and Henry Tsai and Konstantinos Kallidromitis and Anirudh Pai and S. Shankar Sastry and Trevor Darrell and Jitendra Malik and Roei Herzig},
booktitle = {The Fourteenth International Conference on Learning Representations},
year = {2026},
url = {https://openreview.net/forum?id=NZDaMcpXZm}
}
```
```bibtex
@misc{marouani2026revisitingclspatchtoken,
title = {Revisiting [CLS] and Patch Token Interaction in Vision Transformers},
author = {Alexis Marouani and Oriane Siméoni and Hervé Jégou and Piotr Bojanowski and Huy V. Vo},
year = {2026},
eprint = {2602.08626},
archivePrefix = {arXiv},
primaryClass = {cs.CV},
url = {https://arxiv.org/abs/2602.08626},
}
```

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.17.6"
version = "1.20.4"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }
@@ -31,8 +31,8 @@ classifiers = [
"Programming Language :: Python :: 3.12",
]
dependencies = [
"einops>=0.7.0",
"torch>=1.10",
"einops>=0.8.2",
"torch>=2.4",
"torchvision",
]
@@ -44,8 +44,8 @@ test = [
]
[project.urls]
Homepage = "https://github.com/lucidrains/vit-pytorch"
Repository = "https://github.com/lucidrains/vit-pytorch"
Homepage = "https://codeberg.org/lucidrains/vit-pytorch"
Repository = "https://codeberg.org/lucidrains/vit-pytorch"
[tool.setuptools]
include-package-data = true

View File

@@ -26,7 +26,7 @@ class AcceptVideoWrapper(Module):
dim_emb = None,
time_seq_len = None,
embed_is_channel_first = False,
output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
proj_embed_to_dim = None
):
super().__init__()

View File

@@ -103,7 +103,7 @@ class JumboViT(Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
jumbo_cls_dim = dim * jumbo_cls_k

View File

@@ -145,7 +145,7 @@ class ViT(nn.Module):
return x
def forward(self, img):
x = self.img_to_tokens(img)
x = self.transformer(x)
@@ -160,7 +160,7 @@ class Adapter(nn.Module):
*,
vit,
num_memories_per_layer = 10,
num_classes = 2,
):
super().__init__()
assert isinstance(vit, ViT)
@@ -188,7 +188,7 @@ class Adapter(nn.Module):
)
# specialized attention mask to preserve the output of the original ViT
# it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa
attn_mask = torch.ones((num_patches, num_patches), dtype = torch.bool)
attn_mask = F.pad(attn_mask, (1, num_memories_per_layer), value = False) # main tokens cannot attend to learnable memories per layer
@@ -203,7 +203,7 @@ class Adapter(nn.Module):
# add task specific memory tokens
memory_cls_tokens = repeat(self.memory_cls_token, 'd -> b 1 d', b = b)
tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)
# pass memories along with image tokens through transformer for attending

319
vit_pytorch/lejepa.py Normal file
View File

@@ -0,0 +1,319 @@
import random
from functools import wraps
import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torchvision import transforms as T
from einops import rearrange
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def singleton(cache_key):
def inner_fn(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
instance = getattr(self, cache_key)
if instance is not None:
return instance
instance = fn(self, *args, **kwargs)
setattr(self, cache_key, instance)
return instance
return wrapper
return inner_fn
def get_module_device(module):
return next(module.parameters()).device
def l2norm(t, eps = 1e-6):
return F.normalize(t, dim = -1, eps = eps)
# loss function
def sigreg_loss(
x,
num_slices = 1024,
domain = (-5, 5),
num_knots = 17
):
# Randall Balestriero - https://arxiv.org/abs/2511.08544
dim, device = x.shape[-1], x.device
# slice sampling
rand_projs = torch.randn((num_slices, dim), device = device)
rand_projs = l2norm(rand_projs)
# integration points
t = torch.linspace(*domain, num_knots, device = device)
# theoretical CF for N(0, 1) and Gauss. window
exp_f = (-0.5 * t.square()).exp()
# empirical CF
x_t = torch.einsum('... d, m d -> ... m', x, rand_projs)
x_t = rearrange(x_t, '... m -> (...) m')
x_t = rearrange(x_t, 'n m -> n m 1') * t
ecf = (1j * x_t).exp().mean(dim = 0)
# weighted L2 distance
err = ecf.sub(exp_f).abs().square().mul(exp_f)
return torch.trapezoid(err, t, dim = -1).mean()
# augmentation utils
class RandomApply(Module):
def __init__(self, fn, p):
super().__init__()
self.fn = fn
self.p = p
def forward(self, x):
if random.random() > self.p:
return x
return self.fn(x)
# MLP class for projector
class L2Norm(Module):
def forward(self, x, eps = 1e-6):
return l2norm(x, eps)
class MLP(Module):
def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
super().__init__()
layers = []
dims = (dim, *((hidden_size,) * (num_layers - 1)))
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
is_last = ind == (len(dims) - 2)
layers.extend([
nn.Linear(layer_dim_in, layer_dim_out),
nn.GELU() if not is_last else nn.Identity()
])
self.net = nn.Sequential(
*layers,
L2Norm(),
nn.Linear(hidden_size, dim_out)
)
def forward(self, x):
return self.net(x)
# wrapper
class NetWrapper(Module):
def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
super().__init__()
self.net = net
self.layer = layer
self.projector = None
self.projection_hidden_size = projection_hidden_size
self.projection_num_layers = projection_num_layers
self.output_dim = output_dim
self.hidden = {}
self.hook_registered = False
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
def _hook(self, _, input, output):
device = input[0].device
self.hidden[device] = output.flatten(1)
def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
@singleton('projector')
def _get_projector(self, hidden):
_, dim = hidden.shape
projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
return projector.to(hidden)
def get_embedding(self, x):
if self.layer == -1:
return self.net(x)
if not self.hook_registered:
self._register_hook()
self.hidden.clear()
_ = self.net(x)
hidden = self.hidden[x.device]
self.hidden.clear()
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
def forward(self, x, return_projection = True):
embed = self.get_embedding(x)
if not return_projection:
return embed
projector = self._get_projector(embed)
return projector(embed), embed
# main class
class LeJEPA(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer = -2,
projection_hidden_size = 256,
num_classes_K = 65336,
projection_layers = 4,
local_upper_crop_scale = 0.4,
global_lower_crop_scale = 0.5,
target_loss_weight = 1.,
sigreg_loss_weight = 1.,
sigreg_loss_kwargs = dict(
num_slices = 1024,
domain = (-5, 5),
num_knots = 17
),
augment_fn = None,
augment_fn2 = None
):
super().__init__()
self.net = net
# default BYOL augmentation
DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, DEFAULT_AUG)
# local and global crops
self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))
self.encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)
self.target_loss_weight = target_loss_weight
self.sigreg_loss_weight = sigreg_loss_weight
self.sigreg_loss_kwargs = sigreg_loss_kwargs
# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)
# send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
def forward(
self,
x,
return_embedding = False,
return_projection = True
):
if return_embedding:
return self.encoder(x, return_projection = return_projection)
image_one, image_two = self.augment1(x), self.augment2(x)
local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
local_images = torch.cat((local_image_one, local_image_two), dim = 0)
proj_locals, _ = self.encoder(local_images)
proj_local_one, proj_local_two = proj_locals.chunk(2, dim = 0)
with torch.no_grad():
global_images = torch.cat((global_image_one, global_image_two), dim = 0)
proj_globals, _ = self.encoder(global_images)
proj_global_one, proj_global_two = proj_globals.chunk(2, dim = 0)
# invariance loss
mse_loss = F.mse_loss(proj_local_one, proj_global_two) + F.mse_loss(proj_local_two, proj_global_one)
# sigreg loss
sreg_loss = sigreg_loss(proj_locals, **self.sigreg_loss_kwargs)
return mse_loss * self.target_loss_weight + sreg_loss * self.sigreg_loss_weight
# quick run
if __name__ == '__main__':
from vit_pytorch import ViT
model = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
)
learner = LeJEPA(
model,
image_size = 256,
hidden_layer = 'to_latent', # layer name where output is hidden dimension
projection_hidden_size = 256, # projector network hidden dimension
projection_layers = 4, # number of layers in projection network
num_classes_K = 65336, # output dimension
target_loss_weight = 1.0,
sigreg_loss_weight = 1.0
)
opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)
images = torch.randn(8, 3, 256, 256)
loss = learner(images)
opt.zero_grad()
loss.backward()
opt.step()
print('loss:', loss.item())
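For reference, the `sigreg_loss` above matches the empirical characteristic function of random 1-D projections of the embeddings against the characteristic function of a standard Gaussian, under a Gaussian window (this is a reading of the code; see the LeJEPA paper for the exact statement):

```latex
\hat{\varphi}_m(t) \;=\; \frac{1}{N} \sum_{n=1}^{N} e^{\,i\, t\, \langle x_n,\, u_m \rangle},
\qquad
u_m = \frac{g_m}{\lVert g_m \rVert},\;\; g_m \sim \mathcal{N}(0, I_d)

\mathcal{L}_{\text{sigreg}}
\;=\; \frac{1}{M} \sum_{m=1}^{M}
\int_{-5}^{5} \bigl|\, \hat{\varphi}_m(t) - e^{-t^{2}/2} \,\bigr|^{2}\; e^{-t^{2}/2}\; dt
```

with M = `num_slices`, the integral taken over `domain` on `num_knots` points via `torch.trapezoid`, and e^{-t^2/2} serving as both the target characteristic function of N(0, 1) and the integration weight.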

View File

@@ -182,7 +182,7 @@ class LeViT(nn.Module):
def forward(self, img):
x = self.conv_embedding(img)
x = self.backbone(x)
x = self.pool(x)

View File

@@ -52,7 +52,7 @@ class MAE(nn.Module):
if self.encoder.pool == "cls":
tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
elif self.encoder.pool == "mean":
tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)
# calculate the number of patches to be masked, get random indices, and divide them up into masked vs unmasked
@@ -87,7 +87,7 @@ class MAE(nn.Module):
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
# concat the masked tokens to the decoder tokens and attend with decoder
decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
decoder_tokens[batch_range, masked_indices] = mask_tokens

View File

@@ -77,7 +77,7 @@ class Dropsample(nn.Module):
def __init__(self, prob = 0):
super().__init__()
self.prob = prob
def forward(self, x):
device = x.device

View File

@@ -72,7 +72,7 @@ class Dropsample(Module):
def __init__(self, prob = 0):
super().__init__()
self.prob = prob
def forward(self, x):
device = x.device

View File

@@ -1,243 +1,243 @@
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Reduce
# helpers
def conv_1x1_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.SiLU()
)
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
return nn.Sequential(
nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.SiLU()
)
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b p h n d -> b p n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
"""Transformer block described in ViT.
Paper: https://arxiv.org/abs/2010.11929
Based on: https://github.com/lucidrains/vit-pytorch
"""
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads, dim_head, dropout),
FeedForward(dim, mlp_dim, dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class MV2Block(nn.Module):
"""MV2 block described in MobileNetV2.
Paper: https://arxiv.org/pdf/1801.04381
Based on: https://github.com/tonylins/pytorch-mobilenet-v2
"""
def __init__(self, inp, oup, stride=1, expansion=4):
super().__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(inp * expansion)
self.use_res_connect = self.stride == 1 and inp == oup
if expansion == 1:
self.conv = nn.Sequential(
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
else:
self.conv = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
def forward(self, x):
out = self.conv(x)
if self.use_res_connect:
out = out + x
return out
class MobileViTBlock(nn.Module):
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
super().__init__()
self.ph, self.pw = patch_size
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
self.conv2 = conv_1x1_bn(channel, dim)
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
self.conv3 = conv_1x1_bn(dim, channel)
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
def forward(self, x):
y = x.clone()
# Local representations
x = self.conv1(x)
x = self.conv2(x)
# Global representations
_, _, h, w = x.shape
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
# Fusion
x = self.conv3(x)
x = torch.cat((x, y), 1)
x = self.conv4(x)
return x
class MobileViT(nn.Module):
"""MobileViT.
Paper: https://arxiv.org/abs/2110.02178
Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
"""
def __init__(
self,
image_size,
dims,
channels,
num_classes,
expansion=4,
kernel_size=3,
patch_size=(2, 2),
depths=(2, 4, 3)
):
super().__init__()
assert len(dims) == 3, 'dims must be a tuple of 3'
assert len(depths) == 3, 'depths must be a tuple of 3'
ih, iw = image_size
ph, pw = patch_size
assert ih % ph == 0 and iw % pw == 0
init_dim, *_, last_dim = channels
self.conv1 = conv_nxn_bn(3, init_dim, stride=2)
self.stem = nn.ModuleList([])
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.trunk = nn.ModuleList([])
self.trunk.append(nn.ModuleList([
MV2Block(channels[3], channels[4], 2, expansion),
MobileViTBlock(dims[0], depths[0], channels[5],
kernel_size, patch_size, int(dims[0] * 2))
]))
self.trunk.append(nn.ModuleList([
MV2Block(channels[5], channels[6], 2, expansion),
MobileViTBlock(dims[1], depths[1], channels[7],
kernel_size, patch_size, int(dims[1] * 4))
]))
self.trunk.append(nn.ModuleList([
MV2Block(channels[7], channels[8], 2, expansion),
MobileViTBlock(dims[2], depths[2], channels[9],
kernel_size, patch_size, int(dims[2] * 4))
]))
self.to_logits = nn.Sequential(
conv_1x1_bn(channels[-2], last_dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(channels[-1], num_classes, bias=False)
)
def forward(self, x):
x = self.conv1(x)
for conv in self.stem:
x = conv(x)
for conv, attn in self.trunk:
x = conv(x)
x = attn(x)
return self.to_logits(x)

View File

@@ -110,7 +110,7 @@ class ViT(nn.Module):
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
@@ -178,7 +178,7 @@ class MP3(nn.Module):
attended_tokens = self.vit.transformer(tokens, tokens_unmasked)
logits = rearrange(self.mlp_head(attended_tokens), 'b n d -> (b n) d')
# Define labels
labels = repeat(torch.arange(num_patches, device = device), 'n -> (b n)', b = batch)
loss = F.cross_entropy(logits, labels)

View File

@@ -343,11 +343,11 @@ class NaViT(nn.Module):
# need to know how many images for final attention pooling
num_images = torch.tensor(num_images, device = device, dtype = torch.long)
# to patches
x = self.to_patch_embedding(patches)
# factorized 2d absolute positional embedding

View File

@@ -64,7 +64,7 @@ class Attention(Module):
self.to_out = nn.Linear(dim_inner, dim, bias = False)
def forward(
self,
x,
context: Tensor | None = None
):

View File

@@ -64,7 +64,7 @@ class Attention(Module):
self.to_out = nn.Linear(dim_inner, dim, bias = False)
def forward(
self,
x,
context: Tensor | None = None
):

View File

@@ -154,7 +154,7 @@ class PiT(nn.Module):
for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
not_last = ind < (len(depth) - 1)
layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))
if not_last:

View File

@@ -146,7 +146,7 @@ class R2LTransformer(nn.Module):
region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')
# calculate local relative positional bias
h_range = torch.arange(window_size_h, device = device)
w_range = torch.arange(window_size_w, device = device)

View File

@@ -187,7 +187,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w)
# add LIM output
out = out + local_out

View File

@@ -26,7 +26,7 @@ def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
omega = 1. / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)
@@ -57,7 +57,7 @@ class Attend(nn.Module):
config = self.cuda_config if q.is_cuda else self.cpu_config
# flash attention - https://arxiv.org/abs/2205.14135
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(q, k, v)

View File

@@ -34,7 +34,7 @@ def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
z = z.flatten()[:, None] * omega[None, :]
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
@@ -52,7 +52,7 @@ class Attend(Module):
def flash_attn(self, q, k, v):
# flash attention - https://arxiv.org/abs/2205.14135
with torch.backends.cuda.sdp_kernel(**self.config._asdict()):
out = F.scaled_dot_product_attention(q, k, v)

View File

@@ -35,7 +35,7 @@ def FeedForward(dim, hidden_dim):
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):

View File

@@ -98,7 +98,7 @@ class SimpleViT(nn.Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

View File

@@ -26,7 +26,7 @@ def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
z = z.flatten()[:, None] * omega[None, :]
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)

View File

@@ -0,0 +1,241 @@
from __future__ import annotations
import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleList
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def last(arr):
return arr[-1]
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def divisible_by(num, den):
return (num % den) == 0
def posemb_sincos_2d(h, w, dim, temperature = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing = 'ij')
assert divisible_by(dim, 4), 'feature dimension must be multiple of 4 for sincos emb'
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, cross_attend = False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.norm_context = nn.LayerNorm(dim) if cross_attend else None
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, context = None):
x = self.norm(x)
if exists(context):
context = self.norm_context(context)
else:
context = x
q = self.to_q(x)
k, v = self.to_kv(context).chunk(2, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = dots.softmax(dim = -1)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class AttentionResidual(Module):
def __init__(self, fn, dim, heads = 8, dim_head = 64, learned_query = True, disable = False):
super().__init__()
self.fn = fn
self.disable = disable
if disable:
return
self.attn = Attention(dim, heads = heads, dim_head = dim_head, cross_attend = True)
self.learned_query = nn.Parameter(torch.randn(dim)) if learned_query else None
def forward(self, history: list[Tensor]) -> Tensor:
if self.disable:
return self.fn(last(history))
batch, seq_len = history[0].shape[:2]
context = torch.stack(history, dim = 2)
context = rearrange(context, 'b n l d -> (b n) l d')
if exists(self.learned_query):
q = repeat(self.learned_query, 'd -> (b n) 1 d', b = batch, n = seq_len)
else:
q = rearrange(last(history), 'b n d -> (b n) 1 d')
pooled = self.attn(q, context = context)
pooled = rearrange(pooled, '(b n) 1 d -> b n d', b = batch, n = seq_len)
return self.fn(pooled)
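In plain tensor operations, the per-token pooling over the layer history that AttentionResidual performs reduces to attention across the depth axis. A stripped-down sketch without projections or the learned query (purely illustrative):

```python
import torch

b, n, d, num_layers = 2, 4, 8, 3
history = [torch.randn(b, n, d) for _ in range(num_layers)]

context = torch.stack(history, dim = 2)          # (b, n, L, d): every layer's output per token
q = history[-1].unsqueeze(2)                     # query with the latest output: (b, n, 1, d)

attn = (q @ context.transpose(-1, -2) * d ** -0.5).softmax(dim = -1)  # (b, n, 1, L)
pooled = (attn @ context).squeeze(2)             # (b, n, d): per-token mix of all prior layers

assert pooled.shape == (b, n, d)
```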
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, learned_query = True):
super().__init__()
self.layers = ModuleList([])
for ind in range(depth):
is_first = ind == 0
self.layers.append(ModuleList([
AttentionResidual(Attention(dim, heads = heads, dim_head = dim_head), dim, heads = heads, dim_head = dim_head, learned_query = learned_query, disable = is_first),
AttentionResidual(FeedForward(dim, mlp_dim), dim, heads = heads, dim_head = dim_head, learned_query = learned_query),
]))
self.final_pool = AttentionResidual(nn.LayerNorm(dim), dim, heads = heads, dim_head = dim_head, learned_query = learned_query)
def forward(
self,
x,
history: list[Tensor] | None = None,
return_history = False
):
history = [*default(history, [])]
history.append(x)
for attn_residual, ff_residual in self.layers:
history.append(attn_residual(history))
history.append(ff_residual(history))
out = self.final_pool(history)
if return_history:
return out, history
return out
class SimpleViTAttnResidual(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
channels = 3,
dim_head = 64,
learned_query = True
):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, learned_query = learned_query)
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(
self,
img,
history: list[Tensor] | None = None,
return_history = False
):
device, dtype = img.device, img.dtype
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype = dtype)
x = self.transformer(x, history = history, return_history = return_history)
if return_history:
x, history = x
x = x.mean(dim = 1)
x = self.to_latent(x)
out = self.linear_head(x)
if return_history:
return out, history
return out
if __name__ == '__main__':
for learned_query in (True, False):
v = SimpleViTAttnResidual(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
learned_query = learned_query
)
img = torch.randn(2, 3, 256, 256)
preds, history = v(img, return_history = True)
assert preds.shape == (2, 1000)
preds, _ = v(img, history = history, return_history = True)
assert preds.shape == (2, 1000)

View File

@@ -0,0 +1,206 @@
# Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks
# Giyeong Oh et al. https://arxiv.org/abs/2505.11881
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from einops import rearrange
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class OrthogonalResidualUpdate(Module):
def __init__(
self,
block: Module,
dim = None,
double_precision = True,
learned = False
):
super().__init__()
self.block = block
self.double_precision = double_precision
self.learned = learned
if learned:
assert exists(dim)
self.to_modulation = nn.Linear(dim, 2)
def orthog_proj(self, block_out, residual):
use_double, dtype = self.double_precision, residual.dtype
if use_double:
residual, block_out = residual.double(), block_out.double()
# get orthogonal projection of the attention or feedforward output respect to residual
unit = F.normalize(residual, dim = -1)
parallel = (block_out * unit).sum(dim = -1, keepdim = True) * unit
orthogonal = block_out - parallel
# back to original dtype if double precision
if use_double:
parallel, orthogonal = parallel.to(dtype), orthogonal.to(dtype)
return parallel, orthogonal
def forward(self, residual):
block_out = self.block(residual)
parallel_update, orthog_update = self.orthog_proj(block_out, residual)
if self.learned:
parallel_mod, orthog_mod = self.to_modulation(block_out).sigmoid().split(1, dim = -1)
parallel_update = parallel_update * parallel_mod
orthog_update = orthog_update * orthog_mod
else:
parallel_update = 0
return residual + parallel_update + orthog_update
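The decomposition in orthog_proj splits the block output into a component along the residual direction and a component orthogonal to it; a standalone check of that identity (shapes here are assumed for illustration):

```python
import torch
import torch.nn.functional as F

residual = torch.randn(2, 4, 16)
block_out = torch.randn(2, 4, 16)

unit = F.normalize(residual, dim = -1)                              # unit vector along the residual
parallel = (block_out * unit).sum(dim = -1, keepdim = True) * unit  # projection onto that direction
orthogonal = block_out - parallel

# the two parts recompose the block output exactly, and the orthogonal part
# carries no component along the residual direction
assert torch.allclose(parallel + orthogonal, block_out, atol = 1e-6)
assert torch.allclose((orthogonal * unit).sum(dim = -1), torch.zeros(2, 4), atol = 1e-5)
```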
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, orthog_residual_update_kwargs: dict = dict()):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
attn = Attention(dim, heads = heads, dim_head = dim_head)
ff = FeedForward(dim, mlp_dim)
self.layers.append(ModuleList([
OrthogonalResidualUpdate(attn, dim = dim, **orthog_residual_update_kwargs),
OrthogonalResidualUpdate(ff, dim = dim, **orthog_residual_update_kwargs)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x)
x = ff(x)
return self.norm(x)
class SimpleViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, orthog_residual_update_kwargs: dict = dict()):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, orthog_residual_update_kwargs)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
device = img.device
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype=x.dtype)
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
# quick test
if __name__ == '__main__':
vit = SimpleViT(
image_size = 256,
patch_size = 16,
num_classes = 10,
dim = 512,
depth = 2,
heads = 4,
mlp_dim = 2048,
orthog_residual_update_kwargs = dict(
learned = True
)
)
images = torch.randn(2, 3, 256, 256)
assert vit(images).shape == (2, 10)

View File

@@ -186,7 +186,7 @@ class SimpleViT(nn.Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams)

View File

@@ -18,7 +18,7 @@ def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
omega = 1. / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)

View File

@@ -119,7 +119,7 @@ class SimpleViT(nn.Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

View File

@@ -105,7 +105,7 @@ class SimpleViT(nn.Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

View File

@@ -0,0 +1,205 @@
from __future__ import annotations
# Alexis Marouani et al. https://arxiv.org/abs/2602.08626
import torch
from torch import nn, cat, Tensor, is_tensor
from torch.nn import Module, ModuleList
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# classes
class Specialized(Module):
def __init__(self, modules: list[Module]):
super().__init__()
self.fns = ModuleList(modules)
def forward(
self,
x: Tensor | list[Tensor],
token_lens: tuple[int, ...] | None = None
):
was_tensor = is_tensor(x)
if was_tensor:
assert exists(token_lens)
x = x.split(token_lens, dim = 1)
assert len(self.fns) == len(x)
out = tuple(fn(t) for fn, t in zip(self.fns, x))
if was_tensor:
out = cat(out, dim = 1)
return out
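Outside the class, the split-apply-concatenate routing that Specialized performs can be sketched directly (token group sizes here are hypothetical):

```python
import torch
import torch.nn as nn

token_lens = (1, 16)                         # cls token, then patch tokens
x = torch.randn(2, sum(token_lens), 8)

norms = [nn.LayerNorm(8), nn.LayerNorm(8)]   # one specialized module per token group

parts = x.split(token_lens, dim = 1)         # route each group to its own parameters
out = torch.cat([fn(t) for fn, t in zip(norms, parts)], dim = 1)

assert out.shape == x.shape
```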
class FeedForward(Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.norm = Specialized([
nn.LayerNorm(dim),
nn.LayerNorm(dim)
])
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x, token_lens = None):
x = self.norm(x, token_lens = token_lens)
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, specialize_qkv = False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = Specialized([
nn.LayerNorm(dim),
nn.LayerNorm(dim)
])
self.attend = nn.Softmax(dim = -1)
self.specialize_qkv = specialize_qkv
if specialize_qkv:
self.to_qkv = Specialized([
nn.Linear(dim, inner_dim * 3, bias = False),
nn.Linear(dim, inner_dim * 3, bias = False)
])
else:
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, token_lens = None):
x = self.norm(x, token_lens = token_lens)
if self.specialize_qkv:
qkv = self.to_qkv(x, token_lens = token_lens).chunk(3, dim = -1)
else:
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, specialize_qkv_depth):
super().__init__()
self.norm = Specialized([nn.LayerNorm(dim), nn.LayerNorm(dim)])
self.layers = ModuleList([])
for ind in range(depth):
specialize_qkv = ind < specialize_qkv_depth
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, specialize_qkv = specialize_qkv),
FeedForward(dim, mlp_dim)
]))
def forward(self, x, token_lens = None):
for attn, ff in self.layers:
x = attn(x, token_lens = token_lens) + x
x = ff(x, token_lens = token_lens) + x
return self.norm(x, token_lens = token_lens)
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, specialize_qkv_depth = None):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.cls_token = nn.Parameter(torch.randn(dim) * 1e-2)
specialize_qkv_depth = default(specialize_qkv_depth, depth // 3) # the authors found that specializing the qkv projection for the cls token in just the first third of the transformer is sufficient
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, specialize_qkv_depth)
self.pool = 'cls'
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
device = img.device
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype=x.dtype)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)
x = cat((cls_tokens, x), dim = 1)
x = self.transformer(x, token_lens = (1, n))
x = x[:, 0]
x = self.to_latent(x)
return self.linear_head(x)
if __name__ == '__main__':
v = SimpleViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048
)
img = torch.randn(1, 3, 256, 256)
out = v(img)
assert out.shape == (1, 1000)

View File

@@ -120,7 +120,7 @@ class SimpleViT(Module):
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

View File

@@ -41,7 +41,7 @@ def posemb_sincos_2d(
omega = temperature ** -omega
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
pe = pe.type(dtype)
@@ -442,7 +442,8 @@ class VAAT(Module):
self_attn_heads = 4,
self_attn_dim_head = 32,
ast_layer_indices: tuple[int, ...] | None = None,
vit_layer_indices: tuple[int, ...] | None = None
vit_layer_indices: tuple[int, ...] | None = None,
num_advantage_bins = 0
):
super().__init__()
@@ -480,7 +481,7 @@ class VAAT(Module):
assert len(ast_layer_indices) == depth, f'number of ast layer indices {len(ast_layer_indices)} does not match the VAAT depth {depth}'
self.register_buffer('ast_layer_indices', tensor(vit_layer_indices), persistent = False)
self.register_buffer('ast_layer_indices', tensor(ast_layer_indices), persistent = False)
# handle maybe multiple frames
@@ -511,6 +512,14 @@ class VAAT(Module):
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
# handle maybe advantage conditioning
self.has_advantages = num_advantage_bins > 0
self.num_advantage_bins = num_advantage_bins
if self.has_advantages:
self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)
self.layers = ModuleList([])
for _ in range(depth):
@@ -540,14 +549,15 @@ class VAAT(Module):
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
audio_or_spec, # (b v? t) | (b v?f t) - batch, audio len | batch, spec freq, time
*,
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
advantages = None, # (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False,
freeze_ast = False
):
batch = video_or_image.shape[0]
batch, device = video_or_image.shape[0], video_or_image.device
return_loss = exists(actions)
# handle some various input dimensions
@@ -655,53 +665,66 @@ class VAAT(Module):
audio_context = rearrange(audio_hiddens, 'l b v n d -> l b (v n) d')
# get main action tokens and maybe append extra
# main action tokens
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)
has_extra = exists(extra)
# maybe advantage tokens
if has_extra:
assert self.accept_extra_token
empty_token = action_tokens[:, 0:0]
extra_token = self.to_extra_token(extra)
maybe_advantage_embed = empty_token
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
if self.has_advantages and exists(advantages):
if isinstance(advantages, int):
advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)
maybe_advantage_embed = self.advantage_emb(advantages + 1)
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
register_tokens = empty_token
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
if exists(self.register_tokens):
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
# cross attention
# extra
hiddens = [action_tokens]
maybe_extra_embed = empty_token
has_extra = exists(extra)
if has_extra:
assert self.accept_extra_token
maybe_extra_embed = self.to_extra_token(extra)
# pack all tokens for attention
tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')
# transformer
hiddens = [tokens]
for (maybe_film, maybe_self_attn, image_cross_attn, audio_cross_attn, ff), image_layer_context, audio_layer_context in zip(self.layers, image_context, audio_context):
if exists(tasks):
action_tokens = maybe_film(action_tokens, task_emb)
if exists(maybe_film) and exists(tasks):
tokens = maybe_film(tokens, task_emb)
action_tokens = image_cross_attn(action_tokens, image_layer_context) + action_tokens
tokens = image_cross_attn(tokens, image_layer_context) + tokens
action_tokens = audio_cross_attn(action_tokens, audio_layer_context) + action_tokens
tokens = audio_cross_attn(tokens, audio_layer_context) + tokens
if exists(maybe_self_attn):
action_tokens = maybe_self_attn(action_tokens) + action_tokens
tokens = maybe_self_attn(tokens) + tokens
action_tokens = ff(action_tokens) + action_tokens
tokens = ff(tokens) + tokens
hiddens.append(action_tokens)
hiddens.append(tokens)
# unpack registers
# unpack register, advantage, action, and extra tokens
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
# maybe unpack extra
if has_extra:
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')
# norm and prediction
@@ -744,43 +767,51 @@ if __name__ == '__main__':
mlp_dim = 384 * 4
)
for num_adv_bins in (0, 2, 10):
    vaat = VAAT(
        vit,
        ast,
        dim = 512,
        depth = 9,
        heads = 8,
        dim_head = 64,
        mlp_dim = 2048,
        dim_action = 20,
        action_chunk_len = 7,
        time_seq_len = 4,
        num_image_views = 2,
        num_audio_views = 2,
        num_tasks = 4,
        num_advantage_bins = num_adv_bins,
        add_self_attn = True,
        dim_extra_token = 33, # extra token with some variable dimension
        vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
            0, 0, 1, 1, 2, 2, 3, 3, 4
        ),
        ast_layer_indices = (
            1, 1, 1, 2, 2, 2, 3, 3, 3
        )
    )
    images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
    audio = torch.randn(2, 2, 14_100 * 5)
    tasks = torch.randint(0, 4, (2,))
    extra = torch.randn(2, 33) # extra internal state
    actions = torch.randn(2, 7, 20) # actions for learning
    # advantage conditioning
    advantages = None
    if num_adv_bins > 0:
        advantages = torch.randint(-1, num_adv_bins, (2,))
    loss = vaat(images, audio, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
    loss.backward()
    # after much training
    pred_actions, hiddens = vaat(images, audio, advantages = advantages, tasks = tasks, extra = extra, return_hiddens = True)
    assert pred_actions.shape == (2, 7, 20)
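The binned-advantage conditioning exercised above reserves bin index -1 as a "no / unknown advantage" sentinel, which is why the embedding table has `num_advantage_bins + 1` rows and indices are offset by one before lookup. A minimal standalone sketch of just that lookup (sizes chosen arbitrarily):

```python
import torch
from torch import nn

num_advantage_bins = 10
dim = 512

# one extra row so index -1 (unknown advantage) maps to row 0 after the +1 shift
advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)

advantages = torch.randint(-1, num_advantage_bins, (2,))  # (batch,) in [-1, num_bins)
advantage_tokens = advantage_emb(advantages + 1)          # (batch, dim)
assert advantage_tokens.shape == (2, 512)
```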


@@ -278,7 +278,8 @@ class VAT(Module):
add_self_attn = True, # in the paper, they didn't have any ways for the action token to exchange information with the extra token, so we'll just add it as an option
self_attn_heads = 4,
self_attn_dim_head = 32,
vit_layer_indices: tuple[int, ...] | None = None
vit_layer_indices: tuple[int, ...] | None = None,
num_advantage_bins = 0
):
super().__init__()
@@ -324,6 +325,14 @@ class VAT(Module):
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
# handle maybe advantage conditioning
self.has_advantages = num_advantage_bins > 0
self.num_advantage_bins = num_advantage_bins
if self.has_advantages:
self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)
self.layers = ModuleList([])
for _ in range(depth):
@@ -351,13 +360,14 @@ class VAT(Module):
self,
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
*,
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
advantages = None,# (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False
):
batch = video_or_image.shape[0]
batch, device = video_or_image.shape[0], video_or_image.device
return_loss = exists(actions)
# handle some various input dimensions
@@ -423,51 +433,64 @@ class VAT(Module):
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
# main action tokens
action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)
empty_token = action_tokens[:, 0:0]
# maybe advantage tokens
maybe_advantage_embed = empty_token
if self.has_advantages and exists(advantages):
    if isinstance(advantages, int):
        advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)
    maybe_advantage_embed = self.advantage_emb(advantages + 1)
# register tokens
register_tokens = empty_token
if exists(self.register_tokens):
    register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
# extra
maybe_extra_embed = empty_token
has_extra = exists(extra)
if has_extra:
    assert self.accept_extra_token
    maybe_extra_embed = self.to_extra_token(extra)
# pack all tokens for attention
tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')
# transformer
hiddens = [tokens]
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
    if exists(maybe_film) and exists(tasks):
        tokens = maybe_film(tokens, task_emb)
    tokens = cross_attn(tokens, layer_context) + tokens
    if exists(maybe_self_attn):
        tokens = maybe_self_attn(tokens) + tokens
    tokens = ff(tokens) + tokens
    hiddens.append(tokens)
# unpack register, advantage, action, and extra tokens
maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')
# norm and prediction
@@ -501,36 +524,44 @@ if __name__ == '__main__':
mlp_dim = 1024
)
for num_adv_bins in (0, 2, 10):
    vat = VAT(
        vit,
        dim = 512,
        depth = 9,
        heads = 8,
        dim_head = 64,
        mlp_dim = 2048,
        dim_action = 20,
        action_chunk_len = 7,
        time_seq_len = 4,
        num_views = 2,
        num_tasks = 4,
        num_advantage_bins = num_adv_bins,
        add_self_attn = True,
        dim_extra_token = 33, # extra token with some variable dimension
        vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
            0, 0, 1, 1, 2, 2, 3, 3, 4
        )
    )
    images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
    tasks = torch.randint(0, 4, (2,))
    extra = torch.randn(2, 33) # extra internal state
    actions = torch.randn(2, 7, 20) # actions for learning
    # advantage conditioning
    advantages = None
    if num_adv_bins > 0:
        advantages = torch.randint(-1, num_adv_bins, (2,))
    loss = vat(images, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
    loss.backward()
    # after much training
    pred_actions, hiddens = vat(images, advantages = advantages, tasks = tasks, extra = extra, return_hiddens = True)
    assert pred_actions.shape == (2, 7, 20)


@@ -188,6 +188,7 @@ class SigLIPVAT(Module):
self_attn_heads = 4,
self_attn_dim_head = 32,
vit_layer_indices: tuple[int, ...] | None = None,
num_advantage_bins = 0,
siglip_image_size = 224,
siglip_patch_size = 14,
siglip_dim = 1152,
@@ -240,6 +241,14 @@ class SigLIPVAT(Module):
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
# handle maybe advantage conditioning
self.has_advantages = num_advantage_bins > 0
self.num_advantage_bins = num_advantage_bins
if self.has_advantages:
self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)
self.layers = ModuleList([])
for _ in range(depth):
maybe_film = FiLM(dim = dim) if self.has_tasks else None
@@ -281,13 +290,13 @@ class SigLIPVAT(Module):
# Auto-detect prefix based on keys
with safe_open(weights_path, framework = 'pt') as f:
keys = f.keys()
vi_p = ''
if any(k.startswith('paligemma_with_expert.paligemma.model.vision_tower.vision_model') for k in keys):
vi_p = 'paligemma_with_expert.paligemma.model.vision_tower.vision_model.'
elif any(k.startswith('vision_model') for k in keys):
vi_p = 'vision_model.'
pz_state = self.vit.state_dict()
def copy_weight_bias(pz_prefix, vi_prefix):
@@ -333,15 +342,16 @@ class SigLIPVAT(Module):
def forward(
self,
video_or_image, # (b v? c t? h w)
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
*,
extra = None,
tasks = None,
actions = None,
extra = None, # (b d) - batch, dim extra
tasks = None, # (b)
advantages = None,# (b)
actions = None, # (b k d) - batch, action chunk length, action dimension
return_hiddens = False,
freeze_vit = False
):
batch = video_or_image.shape[0]
batch, device = video_or_image.shape[0], video_or_image.device
return_loss = exists(actions)
# handle some various input dimensions
@@ -397,46 +407,62 @@ class SigLIPVAT(Module):
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
# main action tokens
action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)
empty_token = action_tokens[:, 0:0]
# maybe advantage tokens
maybe_advantage_embed = empty_token
if self.has_advantages and exists(advantages):
    if isinstance(advantages, int):
        advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)
    maybe_advantage_embed = self.advantage_emb(advantages + 1)
# register tokens
register_tokens = empty_token
if exists(self.register_tokens):
    register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
# extra
maybe_extra_embed = empty_token
has_extra = exists(extra)
if has_extra:
    maybe_extra_embed = self.to_extra_token(extra)
# pack all tokens for attention
tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')
# transformer
vat_hiddens = [tokens]
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
    if exists(maybe_film) and exists(tasks):
        tokens = maybe_film(tokens, task_emb)
    tokens = cross_attn(tokens, layer_context) + tokens
    if exists(maybe_self_attn):
        tokens = maybe_self_attn(tokens) + tokens
    tokens = ff(tokens) + tokens
    vat_hiddens.append(tokens)
# unpack register, advantage, action, and extra tokens
maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')
# norm and prediction
@@ -456,32 +482,40 @@ class SigLIPVAT(Module):
# quick test
if __name__ == '__main__':
for num_adv_bins in (0, 2, 10):
    vat = SigLIPVAT(
        num_tasks = 4,
        dim_extra_token = 32,
        time_seq_len = 2,
        num_views = 2,
        depth = 4,
        num_advantage_bins = num_adv_bins,
        vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
            0, 1, 26, 27
        )
    )
    vat.load_siglip() # load siglip weights from hf
    # inputs
    images = torch.randn(1, 2, 3, 2, 224, 224) # (b, v, c, t, h, w)
    tasks = torch.randint(0, 4, (1,))
    extra = torch.randn(1, 32)
    actions = torch.randn(1, 50, 32) # actions for learning
    # advantage conditioning
    advantages = None
    if num_adv_bins > 0:
        advantages = torch.randint(-1, num_adv_bins, (1,))
    loss = vat(images, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
    loss.backward()
    # after much training
    pred_actions = vat(images, advantages = advantages, tasks = tasks, extra = extra)
    assert pred_actions.shape == (1, 50, 32)

vit_pytorch/vit_detpool.py (new file, 254 lines)

@@ -0,0 +1,254 @@
from __future__ import annotations
# DetPool ViT - a vit that accepts an object mask and attends and pools only using that mask - table 1
# Dantong Niu et al. - https://openreview.net/forum?id=NZDaMcpXZm
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange, Reduce
# helpers
def exists(val):
return val is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def masked_mean(t, mask, dim = 1, eps = 1e-5):
if not exists(mask):
return t.mean(dim = dim)
mask = rearrange(mask.bool(), '... -> ... 1')
t = t.masked_fill(~mask, 0.)
return t.sum(dim = dim) / mask.sum(dim = dim).clamp(min = eps)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, mask = None):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
mask_value = -torch.finfo(dots.dtype).max
dots = dots.masked_fill(~mask, mask_value)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x, mask = mask) + x
x = ff(x) + x
return self.norm(x)
class ViTDetPool(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, use_cls_token = True, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., mask_generator: Module | None = None):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.patch_height = patch_height
self.patch_width = patch_width
self.downsample_mask = Reduce('b (h p1) (w p2) -> b (h w)', 'max', p1 = patch_height, p2 = patch_width)
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
# maybe cls
self.use_cls_token = use_cls_token
if use_cls_token:
self.cls_token = nn.Parameter(torch.randn(dim) * 1e-2)
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim) * 1e-2)
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
self.mask_generator = mask_generator
def forward(self, img, object_mask = None):
if not exists(object_mask) and exists(self.mask_generator):
with torch.no_grad():
self.mask_generator.eval()
object_mask = self.mask_generator(img)
has_cls = self.use_cls_token
batch, _, height, width = img.shape
tokens = self.to_patch_embedding(img)
seq = tokens.shape[1]
tokens = tokens + self.pos_embedding[:seq]
if has_cls:
cls_token = repeat(self.cls_token, 'd -> b d', b = batch)
tokens, packed_shape = pack((cls_token, tokens), 'b * d')
tokens = self.dropout(tokens)
# handle the attention mask, and for final pooling
mask = None
if exists(object_mask):
assert object_mask.ndim in {3, 2}
if object_mask.shape == (batch, height, width):
mask = self.downsample_mask(object_mask)
else:
mask = object_mask
mask = rearrange(mask, 'b ... -> b (...)')
assert mask.shape == (batch, seq)
if has_cls:
mask = F.pad(mask, (1, 0), value = True)
# attend with maybe mask
tokens = self.transformer(tokens, mask = mask)
if not exists(self.mlp_head):
return tokens
# splice out cls
if has_cls:
_, tokens = unpack(tokens, packed_shape, 'b * d')
if exists(mask):
mask = mask[..., 1:]
# pooling with the mask
pooled = masked_mean(tokens, mask, dim = 1)
pooled = self.to_latent(pooled)
return self.mlp_head(pooled)
# quick test
if __name__ == '__main__':
vit = ViTDetPool(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
object_mask = torch.randint(0, 2, (1, 256, 256)).bool()
preds = vit(img, object_mask = object_mask)
assert preds.shape == (1, 1000)
preds_no_mask = vit(img)
assert preds_no_mask.shape == (1, 1000)
# test with module included
class MockMasker(Module):
def forward(self, img):
batch, _, height, width = img.shape
return torch.ones(batch, height, width).bool()
vit = ViTDetPool(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 1,
heads = 16,
mlp_dim = 2048,
mask_generator = MockMasker()
)
preds = vit(img)
assert preds.shape == (1, 1000)


@@ -31,7 +31,7 @@ class FeedForward(Module):
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
@@ -40,31 +40,31 @@ class Attention(Module):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -79,7 +79,7 @@ class Transformer(Module):
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
@@ -105,73 +105,73 @@ class ViTND(Module):
emb_dropout: float = 0.
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.ndim = ndim
self.pool = pool
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b ({join(dim_names)}) ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, x):
x = self.to_patch_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x[:, 1:].mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
if __name__ == '__main__':
model = ViTND(
ndim = 4,
input_shape = (8, 16, 32, 64),
@@ -185,7 +185,7 @@ if __name__ == '__main__':
dropout = 0.1,
emb_dropout = 0.1
)
occupancy_time = torch.randn(2, 3, 8, 16, 32, 64)
logits = model(occupancy_time)


@@ -121,7 +121,7 @@ class FeedForward(Module):
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
@@ -130,14 +130,14 @@ class Attention(Module):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
@@ -145,7 +145,7 @@ class Attention(Module):
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, polar_pos_emb = None):
x = self.norm(x)
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
@@ -156,12 +156,12 @@ class Attention(Module):
freqs, bias = polar_pos_emb
q = apply_polar_pos_emb(q, freqs)
k = apply_polar_pos_emb(k, freqs + bias)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -180,7 +180,7 @@ class Transformer(Module):
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, pos = None):
# pope embedding
@@ -219,45 +219,45 @@ class ViTND(Module):
pope_init_learned_bias_uniform = False
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
self.ndim = ndim
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.dropout = nn.Dropout(emb_dropout)
# golden gate pope
self.polar_emb = GoldenGatePoPENd(
@@ -269,12 +269,12 @@ class ViTND(Module):
p_zero_freqs = pope_p_zero_freqs,
init_learned_bias_uniform = pope_init_learned_bias_uniform
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, polar_emb = self.polar_emb)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def muon_parameters(self):
params = []
@@ -298,9 +298,9 @@ class ViTND(Module):
return_embed = False
):
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
batch, *spatial_dims, _, device = *x.shape, x.device
# Generate position coordinates
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
@@ -308,12 +308,12 @@ class ViTND(Module):
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
# flatten spatial dimensions for attention with nd rotary
pos = repeat(pos, '... p -> b (...) p', b = batch)
x, packed_shape = pack([x], 'b * d')
x = self.dropout(x)
embed = self.transformer(x, pos)
# return the embed with reconstituted patch shape
@@ -330,7 +330,7 @@ class ViTND(Module):
return self.mlp_head(pooled)
if __name__ == '__main__':
model = ViTND(
ndim = 5,
input_shape = (4, 8, 16, 32, 64),


@@ -75,23 +75,23 @@ class GoldenGateRoPENd(Module):
# input shape: (b, h, n, d) where d = head_dim
# pos shape: (b, n, p) where p = pos_dim
# self.freqs shape: (h, f, p) where f = d // 2
x, y = input.float().chunk(2, dim = -1) # both (b, h, n, f)
# Expand dimensions for broadcasting
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
# Compute theta for each (batch, head, seq, freq)
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)
# Apply rotation
x_out = x * cos_theta - y * sin_theta
y_out = x * sin_theta + y * cos_theta
output = cat((x_out, y_out), dim=-1)
return output.type_as(input)
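The chunk-and-rotate step above is a standard 2D rotation applied pairwise across feature halves, so it preserves the norm of each (x, y) pair regardless of theta; a quick numeric sanity check:

```python
import torch

b, h, n, f = 1, 2, 5, 8
x = torch.randn(b, h, n, f)
y = torch.randn(b, h, n, f)
theta = torch.randn(b, h, n, f)

# same rotation as in GoldenGateRoPENd.forward
x_out = x * torch.cos(theta) - y * torch.sin(theta)
y_out = x * torch.sin(theta) + y * torch.cos(theta)

# cos^2 + sin^2 = 1, so per-pair squared norms are unchanged
assert torch.allclose(x ** 2 + y ** 2, x_out ** 2 + y_out ** 2, atol = 1e-5)
```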
@@ -108,7 +108,7 @@ class FeedForward(Module):
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
@@ -117,15 +117,15 @@ class Attention(Module):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.rotary_emb = rotary_emb
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
@@ -133,24 +133,24 @@ class Attention(Module):
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x, pos = None):
x = self.norm(x)
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
# Apply rotary embeddings if available
if exists(self.rotary_emb):
assert exists(pos)
q = self.rotary_emb(q, pos)
k = self.rotary_emb(k, pos)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -165,7 +165,7 @@ class Transformer(Module):
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, rotary_emb = rotary_emb),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, pos = None):
for attn, ff in self.layers:
x = attn(x, pos) + x
@@ -193,45 +193,45 @@ class ViTND(Module):
rope_p_zero_freqs: float = 0.0
):
super().__init__()
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
self.ndim = ndim
input_shape = ensure_tuple(input_shape, ndim)
patch_size = ensure_tuple(patch_size, ndim)
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
num_patches = 1
for n in num_patches_per_dim:
num_patches *= n
patch_dim = channels
for p in patch_size:
patch_dim *= p
dim_names = 'fghijkl'[:ndim]
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]
input_pattern = f'b c {join(input_dims)}'
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
rearrange_str = f'{input_pattern} -> {output_pattern}'
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
self.to_patch_embedding = nn.Sequential(
Rearrange(rearrange_str, **rearrange_kwargs),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.dropout = nn.Dropout(emb_dropout)
# Create rotary embeddings
self.rotary_emb = GoldenGateRoPENd(
dim_pos = ndim,
@@ -241,12 +241,12 @@ class ViTND(Module):
rope_max_freq = rope_max_freq,
rope_p_zero_freqs = rope_p_zero_freqs
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, rotary_emb = self.rotary_emb)
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def muon_parameters(self):
params = []
@@ -270,9 +270,9 @@ class ViTND(Module):
return_embed = False
):
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
batch, *spatial_dims, _, device = *x.shape, x.device
# Generate position coordinates
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
@@ -280,12 +280,12 @@ class ViTND(Module):
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
# flatten spatial dimensions for attention with nd rotary
pos = repeat(pos, '... p -> b (...) p', b = batch)
x, packed_shape = pack([x], 'b * d')
x = self.dropout(x)
embed = self.transformer(x, pos)
# return the embed with reconstituted patch shape
@@ -303,7 +303,7 @@ class ViTND(Module):
if __name__ == '__main__':
model = ViTND(
ndim = 5,
input_shape = (4, 8, 16, 32, 64),

View File

@@ -231,4 +231,4 @@ if __name__ == '__main__':
decorr_loss = DecorrelationLoss(0.0001, soft_validate_num_sampled = True)
out = decorr_loss(hiddens)
assert out.item() == 0

View File

@@ -0,0 +1,217 @@
from __future__ import annotations
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# functions
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim, bias = False),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.norm = nn.LayerNorm(dim, bias = False)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(
self,
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout = 0.,
keel_residual_scale = None
):
super().__init__()
assert depth > 1
self.layers = ModuleList([])
for _ in range(depth):
self.layers.extend([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
])
num_layers = depth * 2
self.keel_residual_scale = default(keel_residual_scale, num_layers)
self.post_norms = ModuleList([nn.LayerNorm(dim, bias = False) for _ in range(num_layers - 1)])
def forward(self, x):
residual_scale = self.keel_residual_scale
for layer_ind, layer in enumerate(self.layers):
first_layer = layer_ind == 0
residual = x
out = layer(x)
if first_layer:
x = out + residual
continue
post_norm = self.post_norms[layer_ind - 1]
x = post_norm(out + residual * residual_scale)
return x
class ViT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
pool = 'cls',
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
keel_residual_scale = None
):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_cls_tokens = 1 if pool == 'cls' else 0
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout,
keel_residual_scale = keel_residual_scale
)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
def forward(self, img):
batch = img.shape[0]
x = self.to_patch_embedding(img)
cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
x = torch.cat((cls_tokens, x), dim = 1)
seq = x.shape[1]
x = x + self.pos_embedding[:seq]
x = self.dropout(x)
x = self.transformer(x)
if not exists(self.mlp_head):
return x
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
if __name__ == '__main__':
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img)
assert preds.shape == (1, 1000)
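The residual rule in the Transformer above (plain residual on the first layer, then post-norm of `out + residual * scale` with the scale defaulting to the total layer count) can be sketched in isolation. This is a hypothetical minimal example, not the library code; `KeelBlocks` and the use of `nn.Linear` as a stand-in layer are illustrative assumptions.

```python
# Minimal sketch of the scaled post-norm residual used above (hypothetical,
# not the library code): first layer gets an ordinary residual, every later
# layer applies LayerNorm(out + residual * scale), scale defaulting to the
# number of layers (the Transformer above has 2 layers per block).
import torch
from torch import nn

class KeelBlocks(nn.Module):
    def __init__(self, dim, num_layers, residual_scale = None):
        super().__init__()
        assert num_layers > 1
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        self.residual_scale = residual_scale if residual_scale is not None else num_layers
        # one post-norm per layer after the first
        self.post_norms = nn.ModuleList([nn.LayerNorm(dim) for _ in range(num_layers - 1)])

    def forward(self, x):
        for ind, layer in enumerate(self.layers):
            residual = x
            out = layer(x)
            if ind == 0:
                x = out + residual  # first layer: plain residual
                continue
            x = self.post_norms[ind - 1](out + residual * self.residual_scale)
        return x

blocks = KeelBlocks(dim = 8, num_layers = 4)
x = torch.randn(2, 5, 8)
out = blocks(x)
assert out.shape == (2, 5, 8)
```

Because each post-norm re-normalizes the sum, the large residual scale amplifies the skip path relative to the layer output without letting activations grow unboundedly with depth.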

View File

@@ -89,7 +89,6 @@ class Attention(Module):
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(mask):
-mask = rearrange(mask, 'b j -> b 1 1 j')
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
attn = self.attend(dots)
@@ -109,7 +108,7 @@ class Transformer(Module):
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
-Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
FeedForward(dim, mlp_dim, dropout = dropout)
]))