Mirror of https://github.com/lucidrains/vit-pytorch.git, synced 2026-03-28 00:22:32 +00:00
Compare commits (20 commits)

- 7e18d0302e
- b80676e09c
- fc1e727428
- 6032a54b48
- 06a1f42924
- 6ae6a3ab64
- 827300beed
- a7c4e7f79f
- 54ec3f2af5
- 9aa52cce49
- 4c89017444
- 580258d99e
- 6f1caef987
- fb5014f0ee
- 0b7518ef45
- 077d8c188f
- 5888f05300
- d518e89573
- dd6462d19b
- a1ee1daa1a
.github/FUNDING.yml (vendored, deleted)

````diff
@@ -1,3 +0,0 @@
-# These are supported funding model platforms
-
-github: [lucidrains]
````
.github/workflows/python-publish.yml (vendored, deleted)

````diff
@@ -1,36 +0,0 @@
-# This workflow will upload a Python Package using Twine when a release is created
-# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
-
-# This workflow uses actions that are not certified by GitHub.
-# They are provided by a third-party and are governed by
-# separate terms of service, privacy policy, and support
-# documentation.
-
-name: Upload Python Package
-
-on:
-  release:
-    types: [published]
-
-jobs:
-  deploy:
-
-    runs-on: ubuntu-latest
-
-    steps:
-    - uses: actions/checkout@v4
-    - name: Set up Python
-      uses: actions/setup-python@v5
-      with:
-        python-version: '3.x'
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        pip install build
-    - name: Build package
-      run: python -m build
-    - name: Publish package
-      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
-      with:
-        user: __token__
-        password: ${{ secrets.PYPI_API_TOKEN }}
````
.github/workflows/python-test.yml (vendored, deleted)

````diff
@@ -1,34 +0,0 @@
-# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
-# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
-
-name: Test
-
-on:
-  push:
-    branches: [ main ]
-  pull_request:
-    branches: [ main ]
-
-jobs:
-  build:
-
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version: [3.8, 3.9]
-
-    steps:
-    - uses: actions/checkout@v4
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v5
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
-        python -m pip install -e .
-        python -m pip install pytest
-    - name: Test with pytest
-      run: |
-        pytest -q
````
.gitignore (vendored)

````diff
@@ -127,3 +127,6 @@ dmypy.json

 # Pyre type checker
 .pyre/
+
+# scripts
+*.sh
````
README.md (134 changes)

````diff
@@ -49,7 +49,7 @@

 ## Vision Transformer - Pytorch

-Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
+Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the [attention](https://www.youtube.com/watch?v=eMlx5fFNoYc) revolution.

 For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>.
````
````diff
@@ -90,26 +90,26 @@ preds = v(img) # (1, 1000)

 ## Parameters

 - `image_size`: int.
   Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
 - `patch_size`: int.
   Size of patches. `image_size` must be divisible by `patch_size`.
   The number of patches is: `n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
 - `num_classes`: int.
   Number of classes to classify.
 - `dim`: int.
   Last dimension of output tensor after linear transformation `nn.Linear(..., dim)`.
 - `depth`: int.
   Number of Transformer blocks.
 - `heads`: int.
   Number of heads in Multi-head Attention layer.
 - `mlp_dim`: int.
   Dimension of the MLP (FeedForward) layer.
 - `channels`: int, default `3`.
   Number of image's channels.
 - `dropout`: float between `[0, 1]`, default `0.`.
   Dropout rate.
 - `emb_dropout`: float between `[0, 1]`, default `0`.
   Embedding dropout rate.
 - `pool`: string, either `cls` token pooling or `mean` pooling
````
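Tying the list together, here is how those arguments map onto the constructor (an illustrative instantiation mirroring the usage example earlier in the README; the values are examples, not recommendations):

```python
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,   # max(height, width) for rectangular inputs
    patch_size = 32,    # 256 // 32 = 8, so n = 64 patches (> 16)
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    pool = 'cls'        # or 'mean'
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)
```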
````diff
@@ -972,7 +972,7 @@ torch.save(model.state_dict(), './pretrained-net.pt')

 <img src="./images/mp3.png" width="400px"></img>

 New <a href="https://arxiv.org/abs/2207.07611">paper</a> that introduces masked position prediction pre-training criteria. This strategy is more efficient than the Masked Autoencoder strategy and has comparable performance.

 ```python
 import torch
````
````diff
@@ -1358,10 +1358,10 @@ learner = Dino(
     hidden_layer = 'to_latent',        # hidden layer name or index, from which to extract the embedding
     projection_hidden_size = 256,      # projector network hidden dimension
     projection_layers = 4,             # number of layers in projection network
-    num_classes_K = 65336,             # output logits dimensions (referenced as K in paper)
+    num_classes_K = 65536,             # output logits dimensions (referenced as K in paper)
     student_temp = 0.9,                # student temperature
     teacher_temp = 0.04,               # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
     local_upper_crop_scale = 0.4,      # upper bound for local crop - 0.4 was recommended in the paper
     global_lower_crop_scale = 0.5,     # lower bound for global crop - 0.5 was recommended in the paper
     moving_average_decay = 0.9,        # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
     center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
````
````diff
@@ -1735,7 +1735,7 @@ Coming from computer vision and new to transformers? Here are some resources tha

 ```bibtex
 @misc{touvron2020training,
     title = {Training data-efficient image transformers & distillation through attention},
     author = {Hugo Touvron and Matthieu Cord and Matthijs Douze and Francisco Massa and Alexandre Sablayrolles and Hervé Jégou},
     year = {2020},
     eprint = {2012.12877},
@@ -1768,7 +1768,7 @@ Coming from computer vision and new to transformers? Here are some resources tha

 ```bibtex
 @misc{touvron2021going,
     title = {Going deeper with Image Transformers},
     author = {Hugo Touvron and Matthieu Cord and Alexandre Sablayrolles and Gabriel Synnaeve and Hervé Jégou},
     year = {2021},
     eprint = {2103.17239},
@@ -1801,7 +1801,7 @@ Coming from computer vision and new to transformers? Here are some resources tha

 ```bibtex
 @misc{heo2021rethinking,
     title = {Rethinking Spatial Dimensions of Vision Transformers},
     author = {Byeongho Heo and Sangdoo Yun and Dongyoon Han and Sanghyuk Chun and Junsuk Choe and Seong Joon Oh},
     year = {2021},
     eprint = {2103.16302},
@@ -1845,7 +1845,7 @@ Coming from computer vision and new to transformers? Here are some resources tha

 ```bibtex
 @misc{su2021roformer,
     title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
     author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
     year = {2021},
     eprint = {2104.09864},
@@ -1867,7 +1867,7 @@ Coming from computer vision and new to transformers? Here are some resources tha

 ```bibtex
 @misc{chen2021regionvit,
     title = {RegionViT: Regional-to-Local Attention for Vision Transformers},
     author = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan},
     year = {2021},
     eprint = {2106.02689},
@@ -1878,7 +1878,7 @@ Coming from computer vision and new to transformers? Here are some resources tha

 ```bibtex
 @misc{wang2021crossformer,
     title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
     author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
     year = {2021},
     eprint = {2108.00154},
@@ -1900,7 +1900,7 @@ Coming from computer vision and new to transformers? Here are some resources tha

 ```bibtex
 @misc{he2021masked,
     title = {Masked Autoencoders Are Scalable Vision Learners},
     author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
     year = {2021},
     eprint = {2111.06377},
@@ -1911,7 +1911,7 @@ Coming from computer vision and new to transformers? Here are some resources tha

 ```bibtex
 @misc{xie2021simmim,
     title = {SimMIM: A Simple Framework for Masked Image Modeling},
     author = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
     year = {2021},
     eprint = {2111.09886},
@@ -1944,7 +1944,7 @@ Coming from computer vision and new to transformers? Here are some resources tha

 ```bibtex
 @misc{lee2021vision,
     title = {Vision Transformer for Small-Size Datasets},
     author = {Seung Hoon Lee and Seunghyun Lee and Byung Cheol Song},
     year = {2021},
     eprint = {2112.13492},
@@ -1966,7 +1966,7 @@ Coming from computer vision and new to transformers? Here are some resources tha

 ```bibtex
 @misc{yang2022scalablevit,
     title = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer},
     author = {Rui Yang and Hailong Ma and Jie Wu and Yansong Tang and Xuefeng Xiao and Min Zheng and Xiu Li},
     year = {2022},
     eprint = {2203.10790},
````
````diff
@@ -2203,13 +2203,85 @@ Coming from computer vision and new to transformers? Here are some resources tha

 ```bibtex
 @misc{carrigg2025decorrelationspeedsvisiontransformers,
     title = {Decorrelation Speeds Up Vision Transformers},
     author = {Kieran Carrigg and Rob van Gastel and Melda Yeghaian and Sander Dalm and Faysal Boughorbel and Marcel van Gerven},
     year = {2025},
     eprint = {2510.14657},
     archivePrefix = {arXiv},
     primaryClass = {cs.CV},
     url = {https://arxiv.org/abs/2510.14657},
 }
 ```
+
+```bibtex
+@misc{gopalakrishnan2025decouplingwhatwherepolar,
+    title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
+    author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
+    year = {2025},
+    eprint = {2509.10534},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2509.10534},
+}
+```
+
+```bibtex
+@misc{qiu2025gatedattentionlargelanguage,
+    title = {Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free},
+    author = {Zihan Qiu and Zekun Wang and Bo Zheng and Zeyu Huang and Kaiyue Wen and Songlin Yang and Rui Men and Le Yu and Fei Huang and Suozhi Huang and Dayiheng Liu and Jingren Zhou and Junyang Lin},
+    year = {2025},
+    eprint = {2505.06708},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CL},
+    url = {https://arxiv.org/abs/2505.06708}
+}
+```
+
+```bibtex
+@misc{chen2026postlayernormbackstableexpressive,
+    title = {Post-LayerNorm Is Back: Stable, ExpressivE, and Deep},
+    author = {Chen Chen and Lai Wei},
+    year = {2026},
+    eprint = {2601.19895},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2601.19895},
+}
+```
+
+```bibtex
+@misc{intelligence2025pi06vlalearnsexperience,
+    title = {$\pi^{*}_{0.6}$: a VLA That Learns From Experience},
+    author = {Physical Intelligence and Ali Amin and Raichelle Aniceto and Ashwin Balakrishna and Kevin Black and Ken Conley and Grace Connors and James Darpinian and Karan Dhabalia and Jared DiCarlo and Danny Driess and Michael Equi and Adnan Esmail and Yunhao Fang and Chelsea Finn and Catherine Glossop and Thomas Godden and Ivan Goryachev and Lachy Groom and Hunter Hancock and Karol Hausman and Gashon Hussein and Brian Ichter and Szymon Jakubczak and Rowan Jen and Tim Jones and Ben Katz and Liyiming Ke and Chandra Kuchi and Marinda Lamb and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Yao Lu and Vishnu Mano and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Charvi Sharma and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and Will Stoeckle and Alex Swerdlow and James Tanner and Marcel Torne and Quan Vuong and Anna Walling and Haohuan Wang and Blake Williams and Sukwon Yoo and Lili Yu and Ury Zhilinsky and Zhiyuan Zhou},
+    year = {2025},
+    eprint = {2511.14759},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2511.14759},
+}
+```
+
+```bibtex
+@misc{kimiteam2026attentionresiduals,
+    title = {Attention Residuals},
+    author = {Kimi Team and Guangyu Chen and Yu Zhang and Jianlin Su and Weixin Xu and Siyuan Pan and Yaoyu Wang and Yucheng Wang and Guanduo Chen and Bohong Yin and Yutian Chen and Junjie Yan and Ming Wei and Y. Zhang and Fanqing Meng and Chao Hong and Xiaotong Xie and Shaowei Liu and Enzhe Lu and Yunpeng Tai and Yanru Chen and Xin Men and Haiqing Guo and Y. Charles and Haoyu Lu and Lin Sui and Jinguo Zhu and Zaida Zhou and Weiran He and Weixiao Huang and Xinran Xu and Yuzhi Wang and Guokun Lai and Yulun Du and Yuxin Wu and Zhilin Yang and Xinyu Zhou},
+    year = {2026},
+    eprint = {2603.15031},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CL},
+    url = {https://arxiv.org/abs/2603.15031},
+}
+```
+
+```bibtex
+@misc{balestriero2025lejepa,
+    title = {LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics},
+    author = {Randall Balestriero and Yann LeCun},
+    year = {2025},
+    eprint = {2511.08544},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2511.08544},
+}
+```
````
pyproject.toml

````diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "vit-pytorch"
-version = "1.16.2"
+version = "1.19.1"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
@@ -31,8 +31,8 @@ classifiers = [
     "Programming Language :: Python :: 3.12",
 ]
 dependencies = [
-    "einops>=0.7.0",
-    "torch>=1.10",
+    "einops>=0.8.2",
+    "torch>=2.4",
     "torchvision",
 ]
@@ -44,8 +44,8 @@ test = [
 ]

 [project.urls]
-Homepage = "https://github.com/lucidrains/vit-pytorch"
-Repository = "https://github.com/lucidrains/vit-pytorch"
+Homepage = "https://codeberg.org/lucidrains/vit-pytorch"
+Repository = "https://codeberg.org/lucidrains/vit-pytorch"

 [tool.setuptools]
 include-package-data = true
````
||||
````diff
@@ -26,7 +26,7 @@ class AcceptVideoWrapper(Module):
         dim_emb = None,
         time_seq_len = None,
         embed_is_channel_first = False,
         output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
         proj_embed_to_dim = None
     ):
         super().__init__()
````
````diff
@@ -25,12 +25,12 @@ class DistillMixin:
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape

-        cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
+        cls_tokens = repeat(self.cls_token, 'n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim = 1)
-        x += self.pos_embedding[:, :(n + 1)]
+        x += self.pos_embedding[:(n + 1)]

         if distilling:
-            distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
+            distill_tokens = repeat(distill_token, 'n d -> b n d', b = b)
             x = torch.cat((x, distill_tokens), dim = 1)

         x = self._attend(x)
@@ -125,7 +125,7 @@ class DistillWrapper(Module):
         self.alpha = alpha
         self.hard = hard

-        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.distillation_token = nn.Parameter(torch.randn(1, dim))

         self.distill_mlp = nn.Sequential(
             nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
````
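A quick shape check of the token change above (illustrative, assuming `einops` is installed): a `(1, dim)` parameter with the `'n d -> b n d'` pattern produces the same `(b, 1, dim)` tokens the old `(1, 1, dim)` parameter did, just without storing the extra singleton dimension.

```python
import torch
from einops import repeat

dim = 8
cls_token = torch.randn(1, dim)                      # new parameter shape
cls_tokens = repeat(cls_token, 'n d -> b n d', b=4)  # broadcast over batch
print(cls_tokens.shape)                              # torch.Size([4, 1, 8])
```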
````diff
@@ -103,7 +103,7 @@ class JumboViT(Module):
             h = image_height // patch_height,
             w = image_width // patch_width,
             dim = dim,
         )
     )

     jumbo_cls_dim = dim * jumbo_cls_k
````

````diff
@@ -145,7 +145,7 @@ class ViT(nn.Module):
         return x

     def forward(self, img):
         x = self.img_to_tokens(img)

         x = self.transformer(x)
````

````diff
@@ -160,7 +160,7 @@ class Adapter(nn.Module):
         *,
         vit,
         num_memories_per_layer = 10,
         num_classes = 2,
     ):
         super().__init__()
         assert isinstance(vit, ViT)
@@ -188,7 +188,7 @@ class Adapter(nn.Module):
         )

         # specialized attention mask to preserve the output of the original ViT
         # it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa

         attn_mask = torch.ones((num_patches, num_patches), dtype = torch.bool)
         attn_mask = F.pad(attn_mask, (1, num_memories_per_layer), value = False) # main tokens cannot attend to learnable memories per layer
@@ -203,7 +203,7 @@ class Adapter(nn.Module):
         # add task specific memory tokens

         memory_cls_tokens = repeat(self.memory_cls_token, 'd -> b 1 d', b = b)
         tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)

         # pass memories along with image tokens through transformer for attending
````
vit_pytorch/lejepa.py (new file, 319 lines)

```python
import random
from functools import wraps

import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F

from torchvision import transforms as T
from einops import rearrange

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance

            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn

def get_module_device(module):
    return next(module.parameters()).device

def l2norm(t, eps = 1e-6):
    return F.normalize(t, dim = -1, eps = eps)

# loss function

def sigreg_loss(
    x,
    num_slices = 1024,
    domain = (-5, 5),
    num_knots = 17
):
    # Randall Balestriero - https://arxiv.org/abs/2511.08544

    dim, device = x.shape[-1], x.device

    # slice sampling

    rand_projs = torch.randn((num_slices, dim), device = device)
    rand_projs = l2norm(rand_projs)

    # integration points

    t = torch.linspace(*domain, num_knots, device = device)

    # theoretical CF for N(0, 1) and Gauss. window

    exp_f = (-0.5 * t.square()).exp()

    # empirical CF

    x_t = torch.einsum('... d, m d -> ... m', x, rand_projs)
    x_t = rearrange(x_t, '... m -> (...) m')

    x_t = rearrange(x_t, 'n m -> n m 1') * t
    ecf = (1j * x_t).exp().mean(dim = 0)

    # weighted L2 distance

    err = ecf.sub(exp_f).abs().square().mul(exp_f)

    return torch.trapezoid(err, t, dim = -1).mean()

# augmentation utils

class RandomApply(Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)

# MLP class for projector

class L2Norm(Module):
    def forward(self, x, eps = 1e-6):
        return l2norm(x, eps)

class MLP(Module):
    def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
        super().__init__()

        layers = []
        dims = (dim, *((hidden_size,) * (num_layers - 1)))

        for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
            is_last = ind == (len(dims) - 1)

            layers.extend([
                nn.Linear(layer_dim_in, layer_dim_out),
                nn.GELU() if not is_last else nn.Identity()
            ])

        self.net = nn.Sequential(
            *layers,
            L2Norm(),
            nn.Linear(hidden_size, dim_out)
        )

    def forward(self, x):
        return self.net(x)

# wrapper

class NetWrapper(Module):
    def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        self.projector = None
        self.projection_hidden_size = projection_hidden_size
        self.projection_num_layers = projection_num_layers
        self.output_dim = output_dim

        self.hidden = {}
        self.hook_registered = False

    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None

    def _hook(self, _, input, output):
        device = input[0].device
        self.hidden[device] = output.flatten(1)

    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    @singleton('projector')
    def _get_projector(self, hidden):
        _, dim = hidden.shape
        projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
        return projector.to(hidden)

    def get_embedding(self, x):
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        self.hidden.clear()
        _ = self.net(x)
        hidden = self.hidden[x.device]
        self.hidden.clear()

        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    def forward(self, x, return_projection = True):
        embed = self.get_embedding(x)
        if not return_projection:
            return embed

        projector = self._get_projector(embed)
        return projector(embed), embed

# main class

class LeJEPA(nn.Module):
    def __init__(
        self,
        net,
        image_size,
        hidden_layer = -2,
        projection_hidden_size = 256,
        num_classes_K = 65336,
        projection_layers = 4,
        local_upper_crop_scale = 0.4,
        global_lower_crop_scale = 0.5,
        target_loss_weight = 1.,
        sigreg_loss_weight = 1.,
        sigreg_loss_kwargs = dict(
            num_slices = 1024,
            domain = (-5, 5),
            num_knots = 17
        ),
        augment_fn = None,
        augment_fn2 = None
    ):
        super().__init__()
        self.net = net

        # default BYOL augmentation

        DEFAULT_AUG = torch.nn.Sequential(
            RandomApply(
                T.ColorJitter(0.8, 0.8, 0.8, 0.2),
                p = 0.3
            ),
            T.RandomGrayscale(p=0.2),
            T.RandomHorizontalFlip(),
            RandomApply(
                T.GaussianBlur((3, 3), (1.0, 2.0)),
                p = 0.2
            ),
            T.Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225])),
        )

        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, DEFAULT_AUG)

        # local and global crops

        self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
        self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))

        self.encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)

        self.target_loss_weight = target_loss_weight
        self.sigreg_loss_weight = sigreg_loss_weight
        self.sigreg_loss_kwargs = sigreg_loss_kwargs

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))

    def forward(
        self,
        x,
        return_embedding = False,
        return_projection = True
    ):
        if return_embedding:
            return self.encoder(x, return_projection = return_projection)

        image_one, image_two = self.augment1(x), self.augment2(x)

        local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
        global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)

        local_images = torch.cat((local_image_one, local_image_two), dim = 0)
        proj_locals, _ = self.encoder(local_images)
        proj_local_one, proj_local_two = proj_locals.chunk(2, dim = 0)

        with torch.no_grad():
            global_images = torch.cat((global_image_one, global_image_two), dim = 0)
            proj_globals, _ = self.encoder(global_images)
            proj_global_one, proj_global_two = proj_globals.chunk(2, dim = 0)

        # invariance loss

        mse_loss = F.mse_loss(proj_local_one, proj_global_two) + F.mse_loss(proj_local_two, proj_global_one)

        # sigreg loss

        sreg_loss = sigreg_loss(proj_locals, **self.sigreg_loss_kwargs)

        return mse_loss * self.target_loss_weight + sreg_loss * self.sigreg_loss_weight

# quick run

if __name__ == '__main__':
    from vit_pytorch import ViT

    model = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048
    )

    learner = LeJEPA(
        model,
        image_size = 256,
        hidden_layer = 'to_latent',   # layer name where output is hidden dimension
        projection_hidden_size = 256, # projector network hidden dimension
        projection_layers = 4,        # number of layers in projection network
        num_classes_K = 65336,        # output dimension
        target_loss_weight = 1.0,
        sigreg_loss_weight = 1.0
    )

    opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)

    images = torch.randn(8, 3, 256, 256)

    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

    print('loss:', loss.item())
```
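As a quick illustrative check of `sigreg_loss` above (not part of the commit): unit-vector projections of embeddings that already match an isotropic standard normal have an empirical characteristic function close to the N(0, 1) target, so the loss sits near zero, while a shifted or rescaled distribution scores higher.

```python
import torch
from vit_pytorch.lejepa import sigreg_loss

gaussian = torch.randn(4096, 256)          # matches N(0, I) -> loss near 0
shifted  = torch.randn(4096, 256) * 3 + 1  # wrong scale and mean -> larger loss

print(sigreg_loss(gaussian).item())
print(sigreg_loss(shifted).item())
```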
````diff
@@ -182,7 +182,7 @@ class LeViT(nn.Module):
     def forward(self, img):
         x = self.conv_embedding(img)

         x = self.backbone(x)

         x = self.pool(x)
````

````diff
@@ -52,7 +52,7 @@ class MAE(nn.Module):
         if self.encoder.pool == "cls":
             tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
         elif self.encoder.pool == "mean":
             tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)

         # calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked
@@ -87,7 +87,7 @@ class MAE(nn.Module):
         mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)

         # concat the masked tokens to the decoder tokens and attend with decoder

         decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
         decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
         decoder_tokens[batch_range, masked_indices] = mask_tokens
````

````diff
@@ -77,7 +77,7 @@ class Dropsample(nn.Module):
     def __init__(self, prob = 0):
         super().__init__()
         self.prob = prob

     def forward(self, x):
         device = x.device
````

````diff
@@ -72,7 +72,7 @@ class Dropsample(Module):
     def __init__(self, prob = 0):
         super().__init__()
         self.prob = prob

     def forward(self, x):
         device = x.device
````
@@ -1,243 +1,243 @@

```python
import torch
import torch.nn as nn

from einops import rearrange
from einops.layers.torch import Reduce

# helpers

def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )

def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=-1)

        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    """Transformer block described in ViT.
    Paper: https://arxiv.org/abs/2010.11929
    Based on: https://github.com/lucidrains/vit-pytorch
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads, dim_head, dropout),
                FeedForward(dim, mlp_dim, dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class MV2Block(nn.Module):
    """MV2 block described in MobileNetV2.
    Paper: https://arxiv.org/pdf/1801.04381
    Based on: https://github.com/tonylins/pytorch-mobilenet-v2
    """

    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        out = self.conv(x)
        if self.use_res_connect:
            out = out + x
        return out

class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()

        # Local representations
        x = self.conv1(x)
        x = self.conv2(x)

        # Global representations
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)

        # Fusion
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x

class MobileViT(nn.Module):
    """MobileViT.
    Paper: https://arxiv.org/abs/2110.02178
    Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
    """

    def __init__(
        self,
        image_size,
        dims,
        channels,
        num_classes,
        expansion=4,
        kernel_size=3,
        patch_size=(2, 2),
        depths=(2, 4, 3)
    ):
        super().__init__()
        assert len(dims) == 3, 'dims must be a tuple of 3'
        assert len(depths) == 3, 'depths must be a tuple of 3'

        ih, iw = image_size
        ph, pw = patch_size
        assert ih % ph == 0 and iw % pw == 0

        init_dim, *_, last_dim = channels

        self.conv1 = conv_nxn_bn(3, init_dim, stride=2)

        self.stem = nn.ModuleList([])
        self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))

        self.trunk = nn.ModuleList([])
        self.trunk.append(nn.ModuleList([
            MV2Block(channels[3], channels[4], 2, expansion),
            MobileViTBlock(dims[0], depths[0], channels[5],
                           kernel_size, patch_size, int(dims[0] * 2))
        ]))

        self.trunk.append(nn.ModuleList([
            MV2Block(channels[5], channels[6], 2, expansion),
            MobileViTBlock(dims[1], depths[1], channels[7],
                           kernel_size, patch_size, int(dims[1] * 4))
        ]))

        self.trunk.append(nn.ModuleList([
            MV2Block(channels[7], channels[8], 2, expansion),
            MobileViTBlock(dims[2], depths[2], channels[9],
                           kernel_size, patch_size, int(dims[2] * 4))
        ]))

        self.to_logits = nn.Sequential(
            conv_1x1_bn(channels[-2], last_dim),
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(channels[-1], num_classes, bias=False)
        )

    def forward(self, x):
        x = self.conv1(x)

        for conv in self.stem:
            x = conv(x)

        for conv, attn in self.trunk:
            x = conv(x)
            x = attn(x)

        return self.to_logits(x)
```
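An illustrative instantiation of the class above (the XS-style configuration borrowed from the repository README; treat the exact dims and channel schedule as an example):

```python
import torch

mbvit_xs = MobileViT(
    image_size = (256, 256),
    dims = [96, 120, 144],
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)
pred = mbvit_xs(img)  # (1, 1000)
```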
````diff
@@ -110,7 +110,7 @@ class ViT(nn.Module):
         patch_height, patch_width = pair(patch_size)

         assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

         num_patches = (image_height // patch_height) * (image_width // patch_width)
         patch_dim = channels * patch_height * patch_width
````

````diff
@@ -178,7 +178,7 @@ class MP3(nn.Module):

         attended_tokens = self.vit.transformer(tokens, tokens_unmasked)
         logits = rearrange(self.mlp_head(attended_tokens), 'b n d -> (b n) d')

         # Define labels
         labels = repeat(torch.arange(num_patches, device = device), 'n -> (b n)', b = batch)
         loss = F.cross_entropy(logits, labels)
````
````diff
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from functools import partial
+from functools import partial, lru_cache
 from typing import List

 import torch
@@ -9,7 +9,6 @@ from torch import nn, Tensor
 from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence

 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange

 # helpers
@@ -28,6 +27,12 @@ def pair(t):
 def divisible_by(numer, denom):
     return (numer % denom) == 0

+@lru_cache(maxsize=128)
+def posemb_grid(ph, pw, device):
+    h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
+    w_idx = torch.arange(pw, device=device).repeat(ph)
+    return torch.stack([h_idx, w_idx], dim=-1)
+
 # auto grouping images

 def group_images_by_max_seq_len(
````
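The `lru_cache` decorator means each distinct `(ph, pw, device)` grid is built once and then reused across forward passes. A small sketch of the behavior, assuming `posemb_grid` from the hunk above is in scope:

```python
import torch

grid = posemb_grid(2, 3, torch.device('cpu'))
print(grid.tolist())
# [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]

posemb_grid(2, 3, torch.device('cpu'))  # second call with same args is a cache hit
print(posemb_grid.cache_info().hits)    # 1
```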
````diff
@@ -117,8 +122,7 @@ class Attention(nn.Module):
         self.q_norm = RMSNorm(heads, dim_head)
         self.k_norm = RMSNorm(heads, dim_head)

-        self.attend = nn.Softmax(dim = -1)
-        self.dropout = nn.Dropout(dropout)
+        self.dropout_p = dropout

         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -145,19 +149,22 @@ class Attention(nn.Module):
         q = self.q_norm(q)
         k = self.k_norm(k)

-        dots = torch.matmul(q, k.transpose(-1, -2))
-
-        if exists(mask):
-            mask = rearrange(mask, 'b j -> b 1 1 j')
-            dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
-
-        if exists(attn_mask):
-            dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
-
-        attn = self.attend(dots)
-        attn = self.dropout(attn)
-
-        out = torch.matmul(attn, v)
+        # combine masks if both exist
+        if exists(mask) or exists(attn_mask):
+            if exists(mask):
+                mask = rearrange(mask, 'b j -> b 1 1 j')
+            if exists(mask) and exists(attn_mask):
+                attn_mask = mask & attn_mask
+            elif exists(mask):
+                attn_mask = mask
+
+        out = F.scaled_dot_product_attention(
+            q, k, v,
+            attn_mask = attn_mask,
+            dropout_p = self.dropout_p if self.training else 0.,
+            scale = 1. # RMSNorm already includes sqrt(dim) scaling
+        )

         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
@@ -281,42 +288,41 @@ class NaViT(nn.Module):
         for images in batched_images:
             num_images.append(len(images))

-            sequences = []
-            positions = []
-            image_ids = torch.empty((0,), device = device, dtype = torch.long)
-
-            for image_id, image in enumerate(images):
-                assert image.ndim ==3 and image.shape[0] == c
-                image_dims = image.shape[-2:]
-                assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
-
-                ph, pw = map(lambda dim: dim // p, image_dims)
-
-                pos = torch.stack(torch.meshgrid((
-                    arange(ph),
-                    arange(pw)
-                ), indexing = 'ij'), dim = -1)
-
-                pos = rearrange(pos, 'h w c -> (h w) c')
-                seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
-
-                seq_len = seq.shape[-2]
-
-                if has_token_dropout:
-                    token_dropout = self.calc_token_dropout(*image_dims)
-                    num_keep = max(1, int(seq_len * (1 - token_dropout)))
-                    keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
-
-                    seq = seq[keep_indices]
-                    pos = pos[keep_indices]
-
-                image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
-                sequences.append(seq)
-                positions.append(pos)
+            # compute patch dimensions for all images
+            patch_dims = []
+            for image in images:
+                assert image.ndim == 3 and image.shape[0] == c
+                image_dims = image.shape[-2:]
+                assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
+                patch_dims.append((image_dims[0] // p, image_dims[1] // p))
+
+            # extract patches for all images
+            sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images]
+
+            # compute positions - uses lru_cache to avoid redundant computation across forward passes
+            positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims]
+
+            # handle token dropout
+            if has_token_dropout:
+                for i, (seq, pos) in enumerate(zip(sequences, positions)):
+                    image_dims = images[i].shape[-2:]
+                    token_dropout = self.calc_token_dropout(*image_dims)
+                    seq_len = seq.shape[0]
+                    num_keep = max(1, int(seq_len * (1 - token_dropout)))
+                    keep_indices = torch.randn((seq_len,), device=device).topk(num_keep, dim=-1).indices
+                    sequences[i] = seq[keep_indices]
+                    positions[i] = pos[keep_indices]
+
+            # build image_ids efficiently using repeat_interleave
+            patch_counts = [seq.shape[0] for seq in sequences]
+            image_ids = torch.repeat_interleave(
+                arange(len(images)),
+                torch.tensor(patch_counts, device=device)
+            )

             batched_image_ids.append(image_ids)
-            batched_sequences.append(torch.cat(sequences, dim = 0))
-            batched_positions.append(torch.cat(positions, dim = 0))
+            batched_sequences.append(torch.cat(sequences, dim=0))
+            batched_positions.append(torch.cat(positions, dim=0))
@@ -337,11 +343,11 @@ class NaViT(nn.Module):

         # need to know how many images for final attention pooling

         num_images = torch.tensor(num_images, device = device, dtype = torch.long)

         # to patches

         x = self.to_patch_embedding(patches)

         # factorized 2d absolute positional embedding
````
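For the mask combination in the attention hunk above, the key idea is that a `(b, j)` key-padding mask, once given singleton dimensions, broadcasts against any pairwise boolean mask before being handed to SDPA. A minimal sketch, not from the repo:

```python
import torch
import torch.nn.functional as F

b, h, n, d = 2, 4, 6, 8
q = k = v = torch.randn(b, h, n, d)

key_pad = torch.ones(b, n, dtype=torch.bool)
key_pad[:, -1] = False                                # pretend the last token is padding

pairwise = torch.ones(n, n, dtype=torch.bool).tril()  # e.g. a causal constraint

# (b, 1, 1, j) & (n, j) broadcasts to (b, 1, n, j), valid for SDPA
mask = key_pad[:, None, None, :] & pairwise
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([2, 4, 6, 8])
```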
@@ -64,7 +64,7 @@ class Attention(Module):
|
||||
self.to_out = nn.Linear(dim_inner, dim, bias = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
self,
|
||||
x,
|
||||
context: Tensor | None = None
|
||||
):
|
||||
|
||||
@@ -64,7 +64,7 @@ class Attention(Module):
|
||||
self.to_out = nn.Linear(dim_inner, dim, bias = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
self,
|
||||
x,
|
||||
context: Tensor | None = None
|
||||
):
|
||||
|
||||
@@ -154,7 +154,7 @@ class PiT(nn.Module):
|
||||
|
||||
for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
|
||||
not_last = ind < (len(depth) - 1)
|
||||
|
||||
|
||||
layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))
|
||||
|
||||
if not_last:
|
||||
|
||||
@@ -146,7 +146,7 @@ class R2LTransformer(nn.Module):
|
||||
region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')
|
||||
|
||||
# calculate local relative positional bias
|
||||
|
||||
|
||||
h_range = torch.arange(window_size_h, device = device)
|
||||
w_range = torch.arange(window_size_w, device = device)
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
|
||||
|
||||
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w)
|
||||
|
||||
# add LIM output
|
||||
# add LIM output
|
||||
|
||||
out = out + local_out
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
return pe.type(dtype)
|
||||
|
||||
@@ -57,7 +57,7 @@ class Attend(nn.Module):
|
||||
config = self.cuda_config if q.is_cuda else self.cpu_config
|
||||
|
||||
# flash attention - https://arxiv.org/abs/2205.14135
|
||||
|
||||
|
||||
with torch.backends.cuda.sdp_kernel(**config._asdict()):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):

    z = z.flatten()[:, None] * omega[None, :]
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
@@ -52,7 +52,7 @@ class Attend(Module):

    def flash_attn(self, q, k, v):
        # flash attention - https://arxiv.org/abs/2205.14135

        with torch.backends.cuda.sdp_kernel(**self.config._asdict()):
            out = F.scaled_dot_product_attention(q, k, v)
@@ -35,7 +35,7 @@ def FeedForward(dim, hidden_dim):
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, dim),
    )

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
@@ -98,7 +98,7 @@ class SimpleViT(nn.Module):
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
@@ -26,7 +26,7 @@ def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):

    z = z.flatten()[:, None] * omega[None, :]
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
230  vit_pytorch/simple_vit_attn_residual.py  Normal file
@@ -0,0 +1,230 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def last(arr):
    return arr[-1]

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def divisible_by(num, den):
    return (num % den) == 0

def posemb_sincos_2d(h, w, dim, temperature = 10000, dtype = torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing = 'ij')
    assert divisible_by(dim, 4), 'feature dimension must be multiple of 4 for sincos emb'
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pe.type(dtype)

# classes

class FeedForward(Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, context = None):
        x = self.norm(x)
        context = default(context, x)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = tuple(rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class AttentionPool(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, use_learned_query = True):
        super().__init__()
        self.use_learned_query = use_learned_query
        self.norm_context = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads = heads, dim_head = dim_head)

        if use_learned_query:
            self.query = nn.Parameter(torch.randn(dim))

    def forward(self, context, query = None):
        batch = context.shape[0]

        context = self.norm_context(context)

        if self.use_learned_query:
            q = repeat(self.query, 'd -> b 1 d', b = batch)
        else:
            q = query

        return self.attn(q, context = context)

class AttentionResidual(Module):
    """
    replaces the standard residual connection.
    pools from a growing history of all previous outputs via attention,
    then passes the result through the wrapped module (attn or ff).
    the output is appended back to history (in-place list mutation).
    """

    def __init__(self, fn, dim, heads = 8, dim_head = 64, use_learned_query = True):
        super().__init__()
        self.fn = fn
        self.attn_pool = AttentionPool(dim, heads = heads, dim_head = dim_head, use_learned_query = use_learned_query)

    def forward(self, history):
        context = torch.stack(history, dim = 2)
        b, n, l, d = context.shape

        context = rearrange(context, 'b n l d -> (b n) l d')

        last_out = last(history)
        query = rearrange(last_out, 'b n d -> (b n) 1 d')

        pooled = self.attn_pool(context, query = query)

        pooled = rearrange(pooled, '(b n) 1 d -> b n d', b = b, n = n)

        out = self.fn(pooled)

        # mutate history for subsequent layers
        history.append(out)

        return history
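To make the growing-history bookkeeping concrete, here is a shape-level sketch (hypothetical standalone code, not part of the file): each history entry is (b, n, d); stacking along a new depth axis gives (b, n, l, d), and each layer's attention pool reads across the l axis, which grows by one every layer.

import torch

b, n, d = 2, 16, 32
history = [torch.randn(b, n, d)]                    # depth-0 tokens

for _ in range(3):                                  # three pseudo-layers
    context = torch.stack(history, dim = 2)         # (b, n, l, d) - l grows each layer
    pooled = context.mean(dim = 2)                  # stand-in for the attention pool
    history.append(pooled + torch.randn(b, n, d))   # stand-in for fn(pooled)

assert torch.stack(history, dim = 2).shape == (b, n, 4, d)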
class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_learned_query = True):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                AttentionResidual(Attention(dim, heads = heads, dim_head = dim_head), dim, heads = heads, dim_head = dim_head, use_learned_query = use_learned_query),
                AttentionResidual(FeedForward(dim, mlp_dim), dim, heads = heads, dim_head = dim_head, use_learned_query = use_learned_query)
            ]))

        self.final_attn_pool = AttentionResidual(nn.LayerNorm(dim), dim, heads = heads, dim_head = dim_head, use_learned_query = use_learned_query)

    def forward(self, tokens):
        history = [tokens]

        for attn_res, ff_res in self.layers:
            history = attn_res(history)
            history = ff_res(history)

        history = self.final_attn_pool(history)

        return last(history)

class SimpleViTAttnResidual(Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        channels = 3,
        dim_head = 64,
        use_learned_query = True
    ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_learned_query = use_learned_query)

        self.to_latent = nn.Identity()
        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        device, dtype = img.device, img.dtype

        x = self.to_patch_embedding(img)
        x += self.pos_embedding.to(device, dtype = dtype)

        x = self.transformer(x)

        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)

if __name__ == '__main__':
    for use_learned_query in (True, False):
        v = SimpleViTAttnResidual(
            image_size = 256,
            patch_size = 32,
            num_classes = 1000,
            dim = 1024,
            depth = 6,
            heads = 16,
            mlp_dim = 2048,
            use_learned_query = use_learned_query
        )

        img = torch.randn(2, 3, 256, 256)
        preds = v(img)

        assert preds.shape == (2, 1000)
@@ -186,7 +186,7 @@ class SimpleViT(nn.Module):
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams)

@@ -18,7 +18,7 @@ def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pe.type(dtype)

@@ -119,7 +119,7 @@ class SimpleViT(nn.Module):
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

@@ -105,7 +105,7 @@ class SimpleViT(nn.Module):
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

@@ -120,7 +120,7 @@ class SimpleViT(Module):
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

@@ -41,7 +41,7 @@ def posemb_sincos_2d(
    omega = temperature ** -omega

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    pe = pe.type(dtype)
@@ -120,6 +120,12 @@ class Attention(Module):
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)

+       self.to_out_gates = nn.Sequential(
+           nn.Linear(dim, heads),
+           Rearrange('b ... h -> b h ... 1'),
+           nn.Sigmoid()
+       )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)

@@ -150,6 +156,9 @@ class Attention(Module):
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)

+       out = out * self.to_out_gates(x)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
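The addition above follows the per-head attention output gating of arXiv 2505.06708: a sigmoid gate computed from the pre-attention input is broadcast over each head's output. A minimal standalone sketch of the shape flow (all names and sizes are illustrative):

import torch
from torch import nn
from einops.layers.torch import Rearrange

b, h, n, d = 2, 8, 16, 64

to_out_gates = nn.Sequential(
    nn.Linear(h * d, h),              # gate logits per head, from the input tokens
    Rearrange('b n h -> b h n 1'),    # align with the (b, h, n, d) attention output
    nn.Sigmoid()
)

x = torch.randn(b, n, h * d)          # pre-attention input
out = torch.randn(b, h, n, d)         # per-head attention output

gated = out * to_out_gates(x)         # each head's output scaled in [0, 1]
assert gated.shape == (b, h, n, d)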
@@ -433,7 +442,8 @@ class VAAT(Module):
        self_attn_heads = 4,
        self_attn_dim_head = 32,
        ast_layer_indices: tuple[int, ...] | None = None,
-       vit_layer_indices: tuple[int, ...] | None = None
+       vit_layer_indices: tuple[int, ...] | None = None,
+       num_advantage_bins = 0
    ):
        super().__init__()

@@ -471,7 +481,7 @@ class VAAT(Module):

        assert len(ast_layer_indices) == depth, f'number of ast layer indices {len(ast_layer_indices)} does not match the VAAT depth {depth}'

-       self.register_buffer('ast_layer_indices', tensor(vit_layer_indices), persistent = False)
+       self.register_buffer('ast_layer_indices', tensor(ast_layer_indices), persistent = False)

        # handle maybe multiple frames

@@ -502,6 +512,14 @@ class VAAT(Module):

        self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)

+       # handle maybe advantage conditioning

+       self.has_advantages = num_advantage_bins > 0
+       self.num_advantage_bins = num_advantage_bins

+       if self.has_advantages:
+           self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)

        self.layers = ModuleList([])

        for _ in range(depth):
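The `+ 1` offset on the embedding lets bin -1 act as a null, unconditioned advantage. A small standalone sketch of this binning convention (values are hypothetical):

import torch
from torch import nn

num_advantage_bins = 10
advantage_emb = nn.Embedding(num_advantage_bins + 1, 512)

# advantages are integer bin ids in [-1, num_advantage_bins - 1],
# where -1 denotes "no advantage information"
advantages = torch.tensor([-1, 0, 9])

embed = advantage_emb(advantages + 1)   # shift so -1 maps to index 0
assert embed.shape == (3, 512)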
@@ -531,14 +549,15 @@ class VAAT(Module):
        video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
        audio_or_spec,  # (b v? t) | (b v? f t) - batch, audio len | batch, spec freq, time
        *,
        extra = None,     # (b d) - batch, dim extra
        tasks = None,     # (b)
+       advantages = None,# (b)
        actions = None,   # (b k d) - batch, action chunk length, action dimension
        return_hiddens = False,
        freeze_vit = False,
        freeze_ast = False
    ):
-       batch = video_or_image.shape[0]
+       batch, device = video_or_image.shape[0], video_or_image.device
        return_loss = exists(actions)

        # handle some various input dimensions
@@ -646,53 +665,66 @@ class VAAT(Module):

        audio_context = rearrange(audio_hiddens, 'l b v n d -> l b (v n) d')

-       # get main action tokens and maybe append extra
+       # main action tokens

-       action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
+       action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)

-       has_extra = exists(extra)
+       # maybe advantage tokens

-       if has_extra:
-           assert self.accept_extra_token
+       empty_token = action_tokens[:, 0:0]

-           extra_token = self.to_extra_token(extra)
+       maybe_advantage_embed = empty_token

-           action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
+       if self.has_advantages and exists(advantages):
+           if isinstance(advantages, int):
+               advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)
+
+           maybe_advantage_embed = self.advantage_emb(advantages + 1)

        # register tokens

-       register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
+       register_tokens = empty_token

-       action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
+       if exists(self.register_tokens):
+           register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)

-       # cross attention
+       # extra

-       hiddens = [action_tokens]
+       maybe_extra_embed = empty_token
+
+       has_extra = exists(extra)
+       if has_extra:
+           assert self.accept_extra_token
+
+           maybe_extra_embed = self.to_extra_token(extra)
+
+       # pack all tokens for attention
+
+       tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')
+
+       # transformer
+
+       hiddens = [tokens]

        for (maybe_film, maybe_self_attn, image_cross_attn, audio_cross_attn, ff), image_layer_context, audio_layer_context in zip(self.layers, image_context, audio_context):

-           if exists(tasks):
-               action_tokens = maybe_film(action_tokens, task_emb)
+           if exists(maybe_film) and exists(tasks):
+               tokens = maybe_film(tokens, task_emb)

-           action_tokens = image_cross_attn(action_tokens, image_layer_context) + action_tokens
+           tokens = image_cross_attn(tokens, image_layer_context) + tokens

-           action_tokens = audio_cross_attn(action_tokens, audio_layer_context) + action_tokens
+           tokens = audio_cross_attn(tokens, audio_layer_context) + tokens

            if exists(maybe_self_attn):
-               action_tokens = maybe_self_attn(action_tokens) + action_tokens
+               tokens = maybe_self_attn(tokens) + tokens

-           action_tokens = ff(action_tokens) + action_tokens
+           tokens = ff(tokens) + tokens

-           hiddens.append(action_tokens)
+           hiddens.append(tokens)

-       # unpack registers
+       # unpack register, advantage, action, and extra tokens

-       _, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
-
-       # maybe unpack extra
-
-       if has_extra:
-           action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
+       maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')

        # norm and prediction
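The refactor above leans on einops pack/unpack to concatenate variable-length token groups (some possibly empty) and split them back out after the transformer. A minimal standalone sketch:

import torch
from einops import pack, unpack

b, d = 2, 512

register_tokens  = torch.randn(b, 4, d)
advantage_tokens = torch.randn(b, 0, d)   # an empty group packs fine
action_tokens    = torch.randn(b, 7, d)
extra_tokens     = torch.randn(b, 1, d)

tokens, ps = pack((register_tokens, advantage_tokens, action_tokens, extra_tokens), 'b * d')
assert tokens.shape == (b, 12, d)

# ... transformer layers would operate on `tokens` here ...

reg, adv, act, ext = unpack(tokens, ps, 'b * d')
assert act.shape == (b, 7, d)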
@@ -735,43 +767,51 @@ if __name__ == '__main__':
        mlp_dim = 384 * 4
    )

-   vaat = VAAT(
-       vit,
-       ast,
-       dim = 512,
-       depth = 9,
-       heads = 8,
-       dim_head = 64,
-       mlp_dim = 2048,
-       dim_action = 20,
-       action_chunk_len = 7,
-       time_seq_len = 4,
-       num_image_views = 2,
-       num_audio_views = 2,
-       num_tasks = 4,
-       add_self_attn = True,
-       dim_extra_token = 33, # extra token with some variable dimension
-       vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
-           0, 0, 1, 1, 2, 2, 3, 3, 4
-       ),
-       ast_layer_indices = (
-           1, 1, 1, 2, 2, 2, 3, 3, 3
+   for num_adv_bins in (0, 2, 10):
+       vaat = VAAT(
+           vit,
+           ast,
+           dim = 512,
+           depth = 9,
+           heads = 8,
+           dim_head = 64,
+           mlp_dim = 2048,
+           dim_action = 20,
+           action_chunk_len = 7,
+           time_seq_len = 4,
+           num_image_views = 2,
+           num_audio_views = 2,
+           num_tasks = 4,
+           num_advantage_bins = num_adv_bins,
+           add_self_attn = True,
+           dim_extra_token = 33, # extra token with some variable dimension
+           vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
+               0, 0, 1, 1, 2, 2, 3, 3, 4
+           ),
+           ast_layer_indices = (
+               1, 1, 1, 2, 2, 2, 3, 3, 3
+           )
        )

        images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
        audio = torch.randn(2, 2, 14_100 * 5)

        tasks = torch.randint(0, 4, (2,))
        extra = torch.randn(2, 33) # extra internal state

-       actions = torch.randn(2, 7, 20) # actions for learning
+       # advantage conditioning

-       loss = vaat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
-       loss.backward()
+       advantages = None
+       if num_adv_bins > 0:
+           advantages = torch.randint(-1, num_adv_bins, (2,))

-       # after much training
+       actions = torch.randn(2, 7, 20) # actions for learning

-       pred_actions, hiddens = vaat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
+       loss = vaat(images, audio, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
+       loss.backward()

-       assert pred_actions.shape == (2, 7, 20)
+       # after much training
+
+       pred_actions, hiddens = vaat(images, audio, advantages = advantages, tasks = tasks, extra = extra, return_hiddens = True)
+
+       assert pred_actions.shape == (2, 7, 20)
@@ -92,6 +92,12 @@ class Attention(Module):
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)

+       self.to_out_gates = nn.Sequential(
+           nn.Linear(dim, heads),
+           Rearrange('b ... h -> b h ... 1'),
+           nn.Sigmoid()
+       )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)

@@ -122,6 +128,8 @@ class Attention(Module):
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
+       out = out * self.to_out_gates(x) # https://arxiv.org/abs/2505.06708

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
@@ -270,7 +278,8 @@ class VAT(Module):
        add_self_attn = True, # in the paper, they didn't have any ways for the action token to exchange information with the extra token, so we'll just add it as an option
        self_attn_heads = 4,
        self_attn_dim_head = 32,
-       vit_layer_indices: tuple[int, ...] | None = None
+       vit_layer_indices: tuple[int, ...] | None = None,
+       num_advantage_bins = 0
    ):
        super().__init__()

@@ -316,6 +325,14 @@ class VAT(Module):

        self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)

+       # handle maybe advantage conditioning

+       self.has_advantages = num_advantage_bins > 0
+       self.num_advantage_bins = num_advantage_bins

+       if self.has_advantages:
+           self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)

        self.layers = ModuleList([])

        for _ in range(depth):

@@ -343,13 +360,14 @@ class VAT(Module):
        self,
        video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
        *,
        extra = None,     # (b d) - batch, dim extra
        tasks = None,     # (b)
+       advantages = None,# (b)
        actions = None,   # (b k d) - batch, action chunk length, action dimension
        return_hiddens = False,
        freeze_vit = False
    ):
-       batch = video_or_image.shape[0]
+       batch, device = video_or_image.shape[0], video_or_image.device
        return_loss = exists(actions)

        # handle some various input dimensions
@@ -415,51 +433,64 @@ class VAT(Module):

        context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')

-       # get main action tokens and maybe append extra
+       # main action tokens

-       action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
+       action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)

-       has_extra = exists(extra)
+       # maybe advantage tokens

-       if has_extra:
-           assert self.accept_extra_token
+       empty_token = action_tokens[:, 0:0]

-           extra_token = self.to_extra_token(extra)
+       maybe_advantage_embed = empty_token

-           action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
+       if self.has_advantages and exists(advantages):
+           if isinstance(advantages, int):
+               advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)
+
+           maybe_advantage_embed = self.advantage_emb(advantages + 1)

        # register tokens

-       register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
+       register_tokens = empty_token

-       action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
+       if exists(self.register_tokens):
+           register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)

-       # cross attention
+       # extra

-       hiddens = [action_tokens]
+       maybe_extra_embed = empty_token
+
+       has_extra = exists(extra)
+       if has_extra:
+           assert self.accept_extra_token
+
+           maybe_extra_embed = self.to_extra_token(extra)
+
+       # pack all tokens for attention
+
+       tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')
+
+       # transformer
+
+       hiddens = [tokens]

        for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):

-           if exists(tasks):
-               action_tokens = maybe_film(action_tokens, task_emb)
+           if exists(maybe_film) and exists(tasks):
+               tokens = maybe_film(tokens, task_emb)

-           action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
+           tokens = cross_attn(tokens, layer_context) + tokens

            if exists(maybe_self_attn):
-               action_tokens = maybe_self_attn(action_tokens) + action_tokens
+               tokens = maybe_self_attn(tokens) + tokens

-           action_tokens = ff(action_tokens) + action_tokens
+           tokens = ff(tokens) + tokens

-           hiddens.append(action_tokens)
+           hiddens.append(tokens)

-       # unpack registers
+       # unpack register, advantage, action, and extra tokens

-       _, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
-
-       # maybe unpack extra
-
-       if has_extra:
-           action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
+       maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')

        # norm and prediction
@@ -493,36 +524,44 @@ if __name__ == '__main__':
        mlp_dim = 1024
    )

-   vat = VAT(
-       vit,
-       dim = 512,
-       depth = 9,
-       heads = 8,
-       dim_head = 64,
-       mlp_dim = 2048,
-       dim_action = 20,
-       action_chunk_len = 7,
-       time_seq_len = 4,
-       num_views = 2,
-       num_tasks = 4,
-       add_self_attn = True,
-       dim_extra_token = 33, # extra token with some variable dimension
-       vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
-           0, 0, 1, 1, 2, 2, 3, 3, 4
+   for num_adv_bins in (0, 2, 10):
+       vat = VAT(
+           vit,
+           dim = 512,
+           depth = 9,
+           heads = 8,
+           dim_head = 64,
+           mlp_dim = 2048,
+           dim_action = 20,
+           action_chunk_len = 7,
+           time_seq_len = 4,
+           num_views = 2,
+           num_tasks = 4,
+           num_advantage_bins = num_adv_bins,
+           add_self_attn = True,
+           dim_extra_token = 33, # extra token with some variable dimension
+           vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
+               0, 0, 1, 1, 2, 2, 3, 3, 4
+           )
        )

        images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
        tasks = torch.randint(0, 4, (2,))
        extra = torch.randn(2, 33) # extra internal state

-       actions = torch.randn(2, 7, 20) # actions for learning
+       # advantage conditioning

-       loss = vat(images, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
-       loss.backward()
+       advantages = None
+       if num_adv_bins > 0:
+           advantages = torch.randint(-1, num_adv_bins, (2,))

-       # after much training
+       actions = torch.randn(2, 7, 20) # actions for learning

-       pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True)
+       loss = vat(images, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
+       loss.backward()

-       assert pred_actions.shape == (2, 7, 20)
+       # after much training
+
+       pred_actions, hiddens = vat(images, advantages = advantages, tasks = tasks, extra = extra, return_hiddens = True)
+
+       assert pred_actions.shape == (2, 7, 20)
521  vit_pytorch/vat_siglip.py  Normal file
@@ -0,0 +1,521 @@
from __future__ import annotations
from contextlib import nullcontext
from pathlib import Path

import torch
import torch.nn.functional as F
from torch import nn, cat, stack, tensor, einsum
from torch.nn import Module, ModuleList

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# helpers

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# attention

class Attention(Module):
    def __init__(
        self,
        dim,
        dim_context = None,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        norm_eps = 1e-6,
        gate_attn = False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim, eps = norm_eps)

        self.is_cross_attn = exists(dim_context)
        dim_context = default(dim_context, dim)
        self.norm_context = nn.LayerNorm(dim_context, eps = norm_eps) if self.is_cross_attn else None

        self.to_q = nn.Linear(dim, inner_dim)
        self.to_kv = nn.Linear(dim_context, inner_dim * 2)

        self.to_out_gates = nn.Sequential(
            nn.Linear(dim, heads),
            Rearrange('b ... h -> b h ... 1'),
            nn.Sigmoid()
        ) if gate_attn else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None):
        x = self.norm(x)

        if self.is_cross_attn:
            assert exists(context)
            context = self.norm_context(context)
        else:
            context = x

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        if exists(self.to_out_gates):
            out = out * self.to_out_gates(x) # https://arxiv.org/abs/2505.06708

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

def FeedForward(
    dim,
    dim_inner,
    norm_eps = 1e-6
):
    return nn.Sequential(
        nn.LayerNorm(dim, eps = norm_eps),
        nn.Linear(dim, dim_inner),
        nn.GELU(approximate = 'tanh'),
        nn.Linear(dim_inner, dim)
    )

class SigLIP(Module):
    def __init__(
        self,
        image_size = 224,
        patch_size = 14,
        dim = 1152,
        depth = 27,
        heads = 16,
        mlp_dim = 4304,
        norm_eps = 1e-6
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        num_patches = (image_size // patch_size) ** 2
        dim_head = dim // heads

        self.to_patch_embed = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_size * patch_size * 3, dim)
        )

        self.pos_embed = nn.Parameter(torch.randn(num_patches, dim))

        self.layers = ModuleList([])
        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, norm_eps = norm_eps),
                FeedForward(dim = dim, dim_inner = mlp_dim, norm_eps = norm_eps)
            ]))

        self.norm = nn.LayerNorm(dim, eps = norm_eps)

    def forward(self, x, return_hiddens = False):
        x = self.to_patch_embed(x)
        num_patches = x.shape[1]

        x = x + self.pos_embed[:num_patches]

        hiddens = []

        for attn, ff in self.layers:
            hiddens.append(x)

            x = attn(x) + x
            x = ff(x) + x

        out = self.norm(x)

        if not return_hiddens:
            return out

        return out, stack(hiddens)

class FiLM(Module):
    def __init__(self, dim):
        super().__init__()
        proj = nn.Linear(dim, dim * 2)

        self.to_gamma_beta = nn.Sequential(
            proj,
            Rearrange('b (two d) -> two b 1 d', two = 2)
        )

        nn.init.zeros_(proj.weight)
        nn.init.zeros_(proj.bias)

    def forward(self, tokens, cond):
        gamma, beta = self.to_gamma_beta(cond)
        return tokens * gamma + beta

class SigLIPVAT(Module):
    def __init__(
        self,
        *,
        dim = 512,
        depth = 27,
        heads = 8,
        dim_head = 64,
        dim_action = 32,
        mlp_dim = 2048,
        num_views = 1,
        num_tasks = None,
        dim_extra_token = None,
        num_register_tokens = 4,
        action_chunk_len = 50,
        time_seq_len = 1,
        dropout = 0.,
        add_self_attn = True,
        self_attn_heads = 4,
        self_attn_dim_head = 32,
        vit_layer_indices: tuple[int, ...] | None = None,
        num_advantage_bins = 0,
        siglip_image_size = 224,
        siglip_patch_size = 14,
        siglip_dim = 1152,
        siglip_depth = 27,
        siglip_heads = 16,
        siglip_mlp_dim = 4304,
        siglip_norm_eps = 1e-6,
    ):
        super().__init__()

        self.vit = SigLIP(
            image_size = siglip_image_size,
            patch_size = siglip_patch_size,
            dim = siglip_dim,
            depth = siglip_depth,
            heads = siglip_heads,
            mlp_dim = siglip_mlp_dim,
            norm_eps = siglip_norm_eps
        )

        vit_dim = siglip_dim
        self.vit_dim = vit_dim

        vit_layer_indices = default(vit_layer_indices, tuple(range(depth)))
        assert len(vit_layer_indices) == depth, f'number of vit layer indices {len(vit_layer_indices)} does not match the VAT depth {depth}'
        self.register_buffer('layer_indices', tensor(vit_layer_indices), persistent = False)

        # handle maybe multiple frames

        is_video = time_seq_len > 1
        self.is_video = is_video
        self.time_seq_len = time_seq_len
        self.time_pos_emb = nn.Parameter(torch.randn(time_seq_len, vit_dim) * 1e-2) if is_video else None

        # maybe view embeddings

        self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None

        # handle maybe task conditioning

        self.has_tasks = exists(num_tasks)
        if self.has_tasks:
            self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)

        # register tokens

        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)

        # to action tokens

        self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)

        # handle maybe advantage conditioning

        self.has_advantages = num_advantage_bins > 0
        self.num_advantage_bins = num_advantage_bins

        if self.has_advantages:
            self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)

        self.layers = ModuleList([])
        for _ in range(depth):
            maybe_film = FiLM(dim = dim) if self.has_tasks else None
            maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None

            self.layers.append(ModuleList([
                maybe_film,
                maybe_self_attn,
                Attention(dim = dim, dim_context = vit_dim, heads = heads, dim_head = dim_head, dropout = dropout, gate_attn = True),
                FeedForward(dim = dim, dim_inner = mlp_dim)
            ]))

        self.final_norm = nn.LayerNorm(dim)
        self.to_pred_action = nn.Linear(dim, dim_action, bias = False)

        # handle the extra token

        self.accept_extra_token = exists(dim_extra_token)
        if exists(dim_extra_token):
            self.to_extra_token = nn.Linear(dim_extra_token, dim)

    def load_siglip(
        self,
        repo_id = 'google/siglip-so400m-patch14-224',
        folder = 'checkpoints/siglip'
    ):
        folder = Path(folder)
        if not folder.exists():
            from huggingface_hub import snapshot_download
            snapshot_download(
                repo_id = repo_id,
                local_dir = folder,
                allow_patterns = ['config.json', 'model.safetensors']
            )

        from safetensors import safe_open
        weights_path = folder / 'model.safetensors'

        # auto-detect prefix based on keys

        with safe_open(weights_path, framework = 'pt') as f:
            keys = f.keys()

            vi_p = ''
            if any(k.startswith('paligemma_with_expert.paligemma.model.vision_tower.vision_model') for k in keys):
                vi_p = 'paligemma_with_expert.paligemma.model.vision_tower.vision_model.'
            elif any(k.startswith('vision_model') for k in keys):
                vi_p = 'vision_model.'

            pz_state = self.vit.state_dict()

            def copy_weight_bias(pz_prefix, vi_prefix):
                pz_state[f'{pz_prefix}.weight'].copy_(f.get_tensor(f'{vi_prefix}.weight'))
                pz_state[f'{pz_prefix}.bias'].copy_(f.get_tensor(f'{vi_prefix}.bias'))

            # patch embedding

            patch_weight = rearrange(f.get_tensor(f'{vi_p}embeddings.patch_embedding.weight'), 'd c h w -> d (h w c)')
            pz_state['to_patch_embed.1.weight'].copy_(patch_weight)
            pz_state['to_patch_embed.1.bias'].copy_(f.get_tensor(f'{vi_p}embeddings.patch_embedding.bias'))

            # position embedding

            pz_state['pos_embed'].copy_(f.get_tensor(f'{vi_p}embeddings.position_embedding.weight'))

            # transformer layers

            for i in range(self.vit.depth):
                v_pi = f'{vi_p}encoder.layers.{i}'
                v_pz = f'layers.{i}'

                # attention
                copy_weight_bias(f'{v_pz}.0.norm', f'{v_pi}.layer_norm1')
                copy_weight_bias(f'{v_pz}.0.to_q', f'{v_pi}.self_attn.q_proj')

                vk, vv = [f.get_tensor(f'{v_pi}.self_attn.{x}_proj.weight') for x in ('k', 'v')]
                bk, bv = [f.get_tensor(f'{v_pi}.self_attn.{x}_proj.bias') for x in ('k', 'v')]

                pz_state[f'{v_pz}.0.to_kv.weight'].copy_(cat((vk, vv), dim = 0))
                pz_state[f'{v_pz}.0.to_kv.bias'].copy_(cat((bk, bv), dim = 0))

                copy_weight_bias(f'{v_pz}.0.to_out.0', f'{v_pi}.self_attn.out_proj')

                # feedforward
                copy_weight_bias(f'{v_pz}.1.0', f'{v_pi}.layer_norm2')
                copy_weight_bias(f'{v_pz}.1.1', f'{v_pi}.mlp.fc1')
                copy_weight_bias(f'{v_pz}.1.3', f'{v_pi}.mlp.fc2')

            # post-layernorm
            copy_weight_bias('norm', f'{vi_p}post_layernorm')

        self.vit.load_state_dict(pz_state)

        print(f'Successfully loaded SigLIP weights from {repo_id}')
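load_siglip remaps HuggingFace SigLIP parameter names onto this implementation's module names. When adapting it to a different checkpoint, it helps to first list the available keys; a small sketch (the path is hypothetical, and only the documented safetensors API is used):

from safetensors import safe_open

with safe_open('checkpoints/siglip/model.safetensors', framework = 'pt') as f:
    vision_keys = [k for k in f.keys() if 'vision' in k]
    print(len(vision_keys), vision_keys[:5])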
    def forward(
        self,
        video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
        *,
        extra = None,     # (b d) - batch, dim extra
        tasks = None,     # (b)
        advantages = None,# (b)
        actions = None,   # (b k d) - batch, action chunk length, action dimension
        return_hiddens = False,
        freeze_vit = False
    ):
        batch, device = video_or_image.shape[0], video_or_image.device
        return_loss = exists(actions)

        # handle some various input dimensions

        if video_or_image.ndim == 4:
            video_or_image = rearrange(video_or_image, 'b c h w -> b 1 c h w')

        if video_or_image.ndim == 5:
            video_or_image = rearrange(video_or_image, 'b v c h w -> b v c 1 h w')

        assert video_or_image.shape[3] == self.time_seq_len

        # to images

        images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
        images, packed_shape = pack([images], '* c h w')

        # get representation trajectory from vit

        vit_forward_context = torch.no_grad if freeze_vit else nullcontext

        with vit_forward_context():
            embed, hiddens = self.vit(images, return_hiddens = True)

        hiddens = cat((hiddens, embed[None, ...]))

        # extract the hiddens needed for the action cross attention

        hiddens = hiddens[self.layer_indices]

        # pack temporarily for embedding

        hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers

        # maybe add time embeddings

        if self.is_video:
            time_pos_emb = rearrange(self.time_pos_emb, 't d -> t 1 d')
            hiddens = hiddens + time_pos_emb

        # maybe view embeddings

        if exists(self.view_emb):
            view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
            hiddens = hiddens + view_emb

        # maybe tasks

        if exists(tasks):
            task_emb = self.task_emb[tasks]

        # cross from actions to representation trajectory

        context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')

        # main action tokens

        action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)

        # maybe advantage tokens

        empty_token = action_tokens[:, 0:0]

        maybe_advantage_embed = empty_token

        if self.has_advantages and exists(advantages):
            if isinstance(advantages, int):
                advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)

            maybe_advantage_embed = self.advantage_emb(advantages + 1)

        # register tokens

        register_tokens = empty_token

        if exists(self.register_tokens):
            register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)

        # extra

        maybe_extra_embed = empty_token

        has_extra = exists(extra)
        if has_extra:
            maybe_extra_embed = self.to_extra_token(extra)

        # pack all tokens for attention

        tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')

        # transformer

        vat_hiddens = [tokens]

        for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):

            if exists(maybe_film) and exists(tasks):
                tokens = maybe_film(tokens, task_emb)

            tokens = cross_attn(tokens, layer_context) + tokens

            if exists(maybe_self_attn):
                tokens = maybe_self_attn(tokens) + tokens

            tokens = ff(tokens) + tokens

            vat_hiddens.append(tokens)

        # unpack register, advantage, action, and extra tokens

        maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')

        # norm and prediction

        action_tokens = self.final_norm(action_tokens)
        pred_action = self.to_pred_action(action_tokens)

        if not return_loss:
            if not return_hiddens:
                return pred_action

            return pred_action, stack(vat_hiddens)

        assert pred_action.shape[1] == actions.shape[1]

        return F.l1_loss(pred_action, actions)

# quick test

if __name__ == '__main__':
    for num_adv_bins in (0, 2, 10):
        vat = SigLIPVAT(
            num_tasks = 4,
            dim_extra_token = 32,
            time_seq_len = 2,
            num_views = 2,
            depth = 4,
            num_advantage_bins = num_adv_bins,
            vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
                0, 1, 26, 27
            )
        )

        vat.load_siglip() # load siglip weights from hf

        # inputs

        images = torch.randn(1, 2, 3, 2, 224, 224) # (b, v, c, t, h, w)
        tasks = torch.randint(0, 4, (1,))
        extra = torch.randn(1, 32)

        # advantage conditioning

        advantages = None
        if num_adv_bins > 0:
            advantages = torch.randint(-1, num_adv_bins, (1,))

        actions = torch.randn(1, 50, 32) # actions for learning

        loss = vat(images, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
        loss.backward()

        # after much training

        pred_actions = vat(images, advantages = advantages, tasks = tasks, extra = extra)

        assert pred_actions.shape == (1, 50, 32)
@@ -113,7 +113,7 @@ class ViT(Module):
        self.pool = pool
        self.to_latent = nn.Identity()

-       self.mlp_head = nn.Linear(dim, num_classes)
+       self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None

    def forward(self, img):
        batch = img.shape[0]

@@ -129,6 +129,9 @@ class ViT(Module):

        x = self.transformer(x)

+       if self.mlp_head is None:
+           return x

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
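With the change above, passing num_classes = 0 turns ViT into a feature extractor that returns the full token sequence instead of logits. A usage sketch (the hyperparameter values are illustrative):

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 0,    # no head: forward returns the token sequence
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

img = torch.randn(1, 3, 256, 256)
tokens = v(img)         # (1, 1 + 64, 1024): cls token plus the 8x8 patch grid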
@@ -31,7 +31,7 @@ class FeedForward(Module):
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

@@ -40,31 +40,31 @@ class Attention(Module):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

@@ -79,7 +79,7 @@ class Transformer(Module):
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x

@@ -105,73 +105,73 @@ class ViTND(Module):
        emb_dropout: float = 0.
    ):
        super().__init__()

        assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.ndim = ndim
        self.pool = pool

        input_shape = ensure_tuple(input_shape, ndim)
        patch_size = ensure_tuple(patch_size, ndim)

        for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
            assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'

        num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
        num_patches = 1
        for n in num_patches_per_dim:
            num_patches *= n

        patch_dim = channels
        for p in patch_size:
            patch_dim *= p

        dim_names = 'fghijkl'[:ndim]

        input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
        patch_dims = [f'p{i}' for i in range(ndim)]

        input_pattern = f'b c {join(input_dims)}'
        output_pattern = f'b ({join(dim_names)}) ({join(patch_dims)} c)'
        rearrange_str = f'{input_pattern} -> {output_pattern}'

        rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
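For concreteness, here is what the generated rearrange string looks like for ndim = 3 (a standalone sketch of the same string-building logic, with a local stand-in for the module's join helper):

ndim = 3
dim_names = 'fghijkl'[:ndim]

input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
patch_dims = [f'p{i}' for i in range(ndim)]

join = ' '.join
input_pattern = f"b c {join(input_dims)}"
output_pattern = f"b ({join(dim_names)}) ({join(patch_dims)} c)"

print(f'{input_pattern} -> {output_pattern}')
# b c (f p0) (g p1) (h p2) -> b (f g h) (p0 p1 p2 c)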
        self.to_patch_embedding = nn.Sequential(
            Rearrange(rearrange_str, **rearrange_kwargs),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.to_latent = nn.Identity()
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.to_patch_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x[:, 1:].mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

if __name__ == '__main__':

    model = ViTND(
        ndim = 4,
        input_shape = (8, 16, 32, 64),

@@ -185,7 +185,7 @@ if __name__ == '__main__':
        dropout = 0.1,
        emb_dropout = 0.1
    )

    occupancy_time = torch.randn(2, 3, 8, 16, 32, 64)

    logits = model(occupancy_time)
353  vit_pytorch/vit_nd_pope.py  Normal file
@@ -0,0 +1,353 @@
from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import pi, nn, arange, cat, stack, Tensor
from torch.nn import Module, ModuleList
from torch.amp import autocast

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)

def join(arr, delimiter = ' '):
    return delimiter.join(arr)

def ensure_tuple(t, length):
    if isinstance(t, (tuple, list)):
        assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
        return tuple(t)

    return (t,) * length

# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/

# but using polar version instead
# Gopalakrishnan et al. https://arxiv.org/abs/2509.10534

def _phi(m: int) -> float:
    x = 2.0
    for _ in range(10):
        x = (1 + x) ** (1.0 / (m + 1.0))
    return x

def make_directions(n: int, d: int) -> Tensor:
    g = _phi(d)
    alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
    i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
    z = torch.fmod(i * alpha, 1.0)
    directions = torch.erfinv(2.0 * z - 1.0)
    directions = l2norm(directions)
    return directions.float()

class GoldenGatePoPENd(Module):
    def __init__(
        self,
        dim_pos: int,
        heads: int,
        dim_head: int,
        min_freq: float = 1.0,
        max_freq: float = 10000.0,
        p_zero_freqs: float = 0.0, # proportion of frequencies set to 0
        init_learned_bias_uniform = False
    ):
        super().__init__()
        n_freqs = dim_head
        n_zero_freqs = round(p_zero_freqs * n_freqs)

        omega = cat((
            torch.zeros(n_zero_freqs),
            min_freq * (max_freq / min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
        ))

        directions = rearrange(
            make_directions(heads * n_freqs, dim_pos),
            '(h f) p -> h f p',
            h = heads
        )

        omega_expanded = rearrange(omega, 'f -> f 1')
        self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)

        self.learned_bias = nn.Parameter(torch.zeros(heads, dim_head))

        if init_learned_bias_uniform:
            # nn.init wraps the in-place op in no_grad, which a bare
            # .uniform_() on a grad-requiring parameter would not
            nn.init.uniform_(self.learned_bias, -2. * pi, 0.)

    @autocast('cuda', enabled = False)
    def forward(self, pos):

        freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
        positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')

        # compute theta for each (batch, head, seq, freq)

        theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')

        bias = self.learned_bias.clamp(-2. * pi, 0.)
        bias = rearrange(bias, 'h d -> h 1 d')

        return theta, bias

@autocast('cuda', enabled = False)
def apply_polar_pos_emb(t, freqs):
    orig_dtype = t.dtype

    t = t.float()
    t = F.softplus(t)

    out = cat((t * freqs.cos(), t * freqs.sin()), dim = -1)

    return out.type(orig_dtype)
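A quick shape sketch of the polar embedding (standalone, with hypothetical sizes): theta carries one phase per (head, token, frequency), and apply_polar_pos_emb doubles the last dimension by emitting magnitude-times-cos and magnitude-times-sin halves, with softplus keeping the magnitude non-negative.

import torch
import torch.nn.functional as F

b, h, n, d = 2, 4, 9, 16          # batch, heads, tokens, dim_head

theta = torch.randn(b, h, n, d)   # stand-in for the summed per-axis phases
t = torch.randn(b, h, n, d)       # a query or key

mag = F.softplus(t)               # non-negative magnitude
out = torch.cat((mag * theta.cos(), mag * theta.sin()), dim = -1)
assert out.shape == (b, h, n, 2 * d)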
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x, polar_pos_emb = None):
|
||||
x = self.norm(x)
|
||||
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
if exists(polar_pos_emb):
|
||||
freqs, bias = polar_pos_emb
|
||||
q = apply_polar_pos_emb(q, freqs)
|
||||
k = apply_polar_pos_emb(k, freqs + bias)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., polar_emb = None):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.polar_emb = polar_emb
|
||||
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x, pos = None):
|
||||
|
||||
# pope embedding
|
||||
|
||||
polar_pos_emb = None
|
||||
if exists(pos) and exists(self.polar_emb):
|
||||
polar_pos_emb = self.polar_emb(pos)
|
||||
|
||||
# transformer layers
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, polar_pos_emb) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)

class ViTND(Module):
    def __init__(
        self,
        *,
        ndim: int,
        input_shape: int | tuple[int, ...],
        patch_size: int | tuple[int, ...],
        num_classes: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        channels: int = 3,
        dim_head: int = 64,
        dropout: float = 0.,
        emb_dropout: float = 0.,
        pope_min_freq: float = 1.0,
        pope_max_freq: float = 10000.0,
        pope_p_zero_freqs: float = 0.0,
        pope_init_learned_bias_uniform = False
    ):
        super().__init__()

        assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'

        self.ndim = ndim

        input_shape = ensure_tuple(input_shape, ndim)
        patch_size = ensure_tuple(patch_size, ndim)

        for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
            assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'

        num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
        num_patches = 1
        for n in num_patches_per_dim:
            num_patches *= n

        patch_dim = channels
        for p in patch_size:
            patch_dim *= p

        dim_names = 'fghijkl'[:ndim]

        input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
        patch_dims = [f'p{i}' for i in range(ndim)]

        input_pattern = f'b c {join(input_dims)}'
        output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
        rearrange_str = f'{input_pattern} -> {output_pattern}'

        rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}

        self.to_patch_embedding = nn.Sequential(
            Rearrange(rearrange_str, **rearrange_kwargs),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.dropout = nn.Dropout(emb_dropout)

        # golden gate pope

        self.polar_emb = GoldenGatePoPENd(
            dim_pos = ndim,
            heads = heads,
            dim_head = dim_head,
            min_freq = pope_min_freq,
            max_freq = pope_max_freq,
            p_zero_freqs = pope_p_zero_freqs,
            init_learned_bias_uniform = pope_init_learned_bias_uniform
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, polar_emb = self.polar_emb)

        self.to_latent = nn.Identity()
        self.mlp_head = nn.Linear(dim, num_classes)

    def muon_parameters(self):
        # the 2D hidden weights (attention value / output projections and both
        # feedforward projections) - the matrices one would typically hand to a
        # Muon-style optimizer; note the query-key projection is excluded
        params = []

        for m in self.modules():
            if isinstance(m, Attention):
                params.extend([
                    m.to_v.weight,
                    m.to_out[0].weight
                ])
            elif isinstance(m, FeedForward):
                params.extend([
                    m.net[1].weight,
                    m.net[-2].weight
                ])

        return params

    def forward(
        self,
        x,
        return_embed = False
    ):
        x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)

        batch, *spatial_dims, _, device = *x.shape, x.device

        # generate position coordinates

        grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
        grid = torch.meshgrid(*grids, indexing = 'ij')
        pos = stack(grid, dim = -1) # (*spatial_dims, ndim)

        # flatten spatial dimensions for attention with nd rotary

        pos = repeat(pos, '... p -> b (...) p', b = batch)
        x, packed_shape = pack([x], 'b * d')

        x = self.dropout(x)

        embed = self.transformer(x, pos)

        # return the embed with reconstituted patch shape

        if return_embed:
            embed, = unpack(embed, packed_shape, 'b * d')
            return embed

        # pooling to logits

        pooled = reduce(embed, 'b n d -> b d', 'mean')

        pooled = self.to_latent(pooled)
        return self.mlp_head(pooled)

if __name__ == '__main__':

    model = ViTND(
        ndim = 5,
        input_shape = (4, 8, 16, 32, 64),
        patch_size = (2, 2, 4, 4, 8),
        num_classes = 1000,
        dim = 512,
        depth = 6,
        heads = 8,
        mlp_dim = 2048,
        channels = 3,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    data = torch.randn(3, 3, 4, 8, 16, 32, 64)

    logits = model(data)

    embed = model(data, return_embed = True)
    assert embed.shape == (3, 2, 4, 4, 8, 8, 512)
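A minimal sketch of how `muon_parameters` could be consumed, assuming some external `Muon` optimizer class (hypothetical, not part of this repository); everything the helper does not return goes to a conventional optimizer:

    muon_params = model.muon_parameters()
    muon_param_ids = set(map(id, muon_params))
    other_params = [p for p in model.parameters() if id(p) not in muon_param_ids]

    # optim_hidden = Muon(muon_params, lr = 1e-2)   # hypothetical Muon-style optimizer for the 2D hidden weights
    optim_rest = torch.optim.Adam(other_params, lr = 3e-4)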

@@ -75,23 +75,23 @@ class GoldenGateRoPENd(Module):
        # input shape: (b, h, n, d) where d = head_dim
        # pos shape: (b, n, p) where p = pos_dim
        # self.freqs shape: (h, f, p) where f = d // 2

        x, y = input.float().chunk(2, dim = -1) # both (b, h, n, f)

        # Expand dimensions for broadcasting
        freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
        positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')

        # Compute theta for each (batch, head, seq, freq)
        theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')

        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)

        # Apply rotation
        x_out = x * cos_theta - y * sin_theta
        y_out = x * sin_theta + y * cos_theta

        output = cat((x_out, y_out), dim=-1)
        return output.type_as(input)
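For intuition, the two rotation lines above are a plain 2D rotation of each feature pair, i.e. complex multiplication by exp(i * theta); a quick standalone torch check of that identity (not part of the file, just a sanity sketch):

    theta = torch.tensor(0.3)
    x, y = torch.tensor(1.0), torch.tensor(2.0)
    x_out = x * torch.cos(theta) - y * torch.sin(theta)
    y_out = x * torch.sin(theta) + y * torch.cos(theta)
    # rotating (x, y) by theta equals multiplying x + iy by exp(i * theta)
    assert torch.allclose(torch.complex(x_out, y_out), torch.complex(x, y) * torch.exp(1j * theta))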

@@ -108,7 +108,7 @@ class FeedForward(Module):
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

@@ -117,15 +117,15 @@ class Attention(Module):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.rotary_emb = rotary_emb

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)

@@ -133,24 +133,24 @@ class Attention(Module):
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x, pos = None):
        x = self.norm(x)
        qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        # Apply rotary embeddings if available
        if exists(self.rotary_emb):
            assert exists(pos)
            q = self.rotary_emb(q, pos)
            k = self.rotary_emb(k, pos)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

@@ -165,7 +165,7 @@ class Transformer(Module):
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, rotary_emb = rotary_emb),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x, pos = None):
        for attn, ff in self.layers:
            x = attn(x, pos) + x

@@ -193,45 +193,45 @@ class ViTND(Module):
        rope_p_zero_freqs: float = 0.0
    ):
        super().__init__()

        assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'

        self.ndim = ndim

        input_shape = ensure_tuple(input_shape, ndim)
        patch_size = ensure_tuple(patch_size, ndim)

        for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
            assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'

        num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
        num_patches = 1
        for n in num_patches_per_dim:
            num_patches *= n

        patch_dim = channels
        for p in patch_size:
            patch_dim *= p

        dim_names = 'fghijkl'[:ndim]

        input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
        patch_dims = [f'p{i}' for i in range(ndim)]

        input_pattern = f'b c {join(input_dims)}'
        output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
        rearrange_str = f'{input_pattern} -> {output_pattern}'

        rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}

        self.to_patch_embedding = nn.Sequential(
            Rearrange(rearrange_str, **rearrange_kwargs),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.dropout = nn.Dropout(emb_dropout)

        # Create rotary embeddings
        self.rotary_emb = GoldenGateRoPENd(
            dim_pos = ndim,

@@ -241,12 +241,12 @@ class ViTND(Module):
            rope_max_freq = rope_max_freq,
            rope_p_zero_freqs = rope_p_zero_freqs
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, rotary_emb = self.rotary_emb)

        self.to_latent = nn.Identity()
        self.mlp_head = nn.Linear(dim, num_classes)

    def muon_parameters(self):
        params = []

@@ -270,9 +270,9 @@ class ViTND(Module):
        return_embed = False
    ):
        x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)

        batch, *spatial_dims, _, device = *x.shape, x.device

        # Generate position coordinates

        grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]

@@ -280,12 +280,12 @@ class ViTND(Module):
        pos = stack(grid, dim = -1) # (*spatial_dims, ndim)

        # flatten spatial dimensions for attention with nd rotary

        pos = repeat(pos, '... p -> b (...) p', b = batch)
        x, packed_shape = pack([x], 'b * d')

        x = self.dropout(x)

        embed = self.transformer(x, pos)

        # return the embed with reconstituted patch shape

@@ -303,7 +303,7 @@ class ViTND(Module):

if __name__ == '__main__':

    model = ViTND(
        ndim = 5,
        input_shape = (4, 8, 16, 32, 64),

@@ -231,4 +231,4 @@ if __name__ == '__main__':

    decorr_loss = DecorrelationLoss(0.0001, soft_validate_num_sampled = True)
    out = decorr_loss(hiddens)
    assert out.item() == 0
    assert out.item() == 0
217
vit_pytorch/vit_with_keel_post_ln.py
Normal file
217
vit_pytorch/vit_with_keel_post_ln.py
Normal file
@@ -0,0 +1,217 @@
from __future__ import annotations

import torch
from torch import nn
from torch.nn import Module, ModuleList

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class FeedForward(Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim, bias = False),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.norm = nn.LayerNorm(dim, bias = False)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(Module):
    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        mlp_dim,
        dropout = 0.,
        keel_residual_scale = None
    ):
        super().__init__()
        assert depth > 1

        self.layers = ModuleList([])

        for _ in range(depth):
            self.layers.extend([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ])

        num_layers = depth * 2
        self.keel_residual_scale = default(keel_residual_scale, num_layers)

        self.post_norms = ModuleList([nn.LayerNorm(dim, bias = False) for _ in range(num_layers - 1)])

    def forward(self, x):
        residual_scale = self.keel_residual_scale

        for layer_ind, layer in enumerate(self.layers):
            first_layer = layer_ind == 0

            residual = x

            out = layer(x)

            if first_layer:
                x = out + residual
                continue

            post_norm = self.post_norms[layer_ind - 1]

            x = post_norm(out + residual * residual_scale)

        return x
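
# note on the update rule above, as implemented: the first block is a plain
# residual, and every later block computes x <- LayerNorm(layer(x) + scale * x),
# where scale = keel_residual_scale defaults to the total number of blocks
# (depth * 2) - a post-LN stack whose identity path is scaled up before each norm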

class ViT(Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        pool = 'cls',
        channels = 3,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        keel_residual_scale = None
    ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        num_cls_tokens = 1 if pool == 'cls' else 0

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
        self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(
            dim,
            depth,
            heads,
            dim_head,
            mlp_dim,
            dropout,
            keel_residual_scale = keel_residual_scale
        )

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None

    def forward(self, img):
        batch = img.shape[0]
        x = self.to_patch_embedding(img)

        cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
        x = torch.cat((cls_tokens, x), dim = 1)

        seq = x.shape[1]

        x = x + self.pos_embedding[:seq]
        x = self.dropout(x)

        x = self.transformer(x)

        if not exists(self.mlp_head):
            return x

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

if __name__ == '__main__':
    v = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    img = torch.randn(1, 3, 256, 256)

    preds = v(img)

    assert preds.shape == (1, 1000)
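Because `mlp_head` is only created when `num_classes > 0`, the same class doubles as a feature extractor; a small sketch returning the full post-transformer token sequence:

    # num_classes = 0 skips the classification head, so forward returns all tokens
    encoder = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 0,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048
    )

    tokens = encoder(torch.randn(1, 3, 256, 256))
    assert tokens.shape == (1, 65, 1024) # 64 patches + 1 cls token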

@@ -1,5 +1,10 @@
from collections import namedtuple

import torch
from torch import nn
from torch import nn, cat
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.nn.attention import SDPBackend, sdpa_kernel

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange

@@ -9,12 +14,15 @@ from einops.layers.torch import Rearrange
def exists(val):
    return val is not None

def divisible_by(num, den):
    return (num % den) == 0

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class FeedForward(nn.Module):
class FeedForward(Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(

@@ -28,9 +36,11 @@ class FeedForward(nn.Module):
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_flash_attn = True):
        super().__init__()
        self.use_flash_attn = use_flash_attn
        self.dropout_p = dropout
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

@@ -48,61 +58,100 @@ class Attention(nn.Module):
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
    def flash_attn(self, q, k, v, mask = None):

        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):

            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout_p,
                is_causal = False,
                scale = self.scale
            )

        return out

    def forward(self, x, mask = None):
        batch, seq, _ = x.shape

        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')

        attn = self.attend(dots)
        attn = self.dropout(attn)
        if self.use_flash_attn:
            out = self.flash_attn(q, k, v, mask = mask)

        else:
            dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

            if exists(mask):
                dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)

            attn = self.attend(dots)
            attn = self.dropout(attn)

            out = torch.matmul(attn, v)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
        super().__init__()
        self.use_flash_attn = use_flash_attn

        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        self.layers = ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x):

    def forward(self, x, mask = None):

        for attn, ff in self.layers:
            x = attn(x) + x
            x = attn(x, mask = mask) + x
            x = ff(x) + x

        return self.norm(x)

class FactorizedTransformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
class FactorizedTransformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
        super().__init__()
        self.use_flash_attn = use_flash_attn

        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        b, f, n, _ = x.shape
    def forward(self, x, mask = None):
        batch, frames, seq, _ = x.shape

        if exists(mask):
            mask = repeat(mask, 'b ... -> (b space) ...', space = x.shape[2])

        for spatial_attn, temporal_attn, ff in self.layers:
            x = rearrange(x, 'b f n d -> (b f) n d')
            x = spatial_attn(x) + x
            x = rearrange(x, '(b f) n d -> (b n) f d', b=b, f=f)
            x = temporal_attn(x) + x
            x = rearrange(x, '(b f) n d -> (b n) f d', b = batch, f = frames)
            x = temporal_attn(x, mask = mask) + x
            x = ff(x) + x
            x = rearrange(x, '(b n) f d -> b f n d', b=b, n=n)
            x = rearrange(x, '(b n) f d -> b f n d', b = batch, n = seq)

        return self.norm(x)

class ViT(nn.Module):
class ViViT(Module):
    def __init__(
        self,
        *,

@@ -122,13 +171,14 @@ class ViT(nn.Module):
        dropout = 0.,
        emb_dropout = 0.,
        variant = 'factorized_encoder',
        use_flash_attn: bool = True,
    ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(image_patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
        assert divisible_by(image_height, patch_height) and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert divisible_by(frames, frame_patch_size), 'Frames must be divisible by frame patch size'
        assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'

        num_image_patches = (image_height // patch_height) * (image_width // patch_width)

@@ -138,6 +188,8 @@ class ViT(nn.Module):

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.frame_patch_size = frame_patch_size

        self.global_average_pool = pool == 'mean'

        self.to_patch_embedding = nn.Sequential(

@@ -154,11 +206,11 @@ class ViT(nn.Module):

        if variant == 'factorized_encoder':
            self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
            self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
            self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
            self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
            self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
        elif variant == 'factorized_self_attention':
            assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
            self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
            self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)

        self.pool = pool
        self.to_latent = nn.Identity()

@@ -166,25 +218,36 @@ class ViT(nn.Module):
        self.mlp_head = nn.Linear(dim, num_classes)
        self.variant = variant

    def forward(self, video):
        x = self.to_patch_embedding(video)
        b, f, n, _ = x.shape
    def forward(self, video, mask = None):
        device = video.device

        x = x + self.pos_embedding[:, :f, :n]
        x = self.to_patch_embedding(video)
        batch, frames, seq, _ = x.shape

        x = x + self.pos_embedding[:, :frames, :seq]

        if exists(self.spatial_cls_token):
            spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
            x = torch.cat((spatial_cls_tokens, x), dim = 2)
            spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = batch, f = frames)
            x = cat((spatial_cls_tokens, x), dim = 2)

        x = self.dropout(x)

        # maybe temporal mask

        temporal_mask = None

        if exists(mask):
            temporal_mask = reduce(mask, 'b (f patch) -> b f', 'all', patch = self.frame_patch_size)

        # the two variants

        if self.variant == 'factorized_encoder':
            x = rearrange(x, 'b f n d -> (b f) n d')

            # attend across space

            x = self.spatial_transformer(x)
            x = rearrange(x, '(b f) n d -> b f n d', b = b)
            x = rearrange(x, '(b f) n d -> b f n d', b = batch)

            # excise out the spatial cls tokens or average pool for temporal attention

@@ -193,22 +256,50 @@ class ViT(nn.Module):
            # append temporal CLS tokens

            if exists(self.temporal_cls_token):
                temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
                temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = batch)

                x = torch.cat((temporal_cls_tokens, x), dim = 1)

                x = cat((temporal_cls_tokens, x), dim = 1)

                if exists(temporal_mask):
                    temporal_mask = F.pad(temporal_mask, (1, 0), value = True)

            # attend across time

            x = self.temporal_transformer(x)
            x = self.temporal_transformer(x, mask = temporal_mask)

            # excise out temporal cls token or average pool

            x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')

        elif self.variant == 'factorized_self_attention':
            x = self.factorized_transformer(x)

            x = self.factorized_transformer(x, mask = temporal_mask)

            x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')

        x = self.to_latent(x)
        return self.mlp_head(x)

# main

if __name__ == '__main__':

    vivit = ViViT(
        dim = 512,
        spatial_depth = 2,
        temporal_depth = 2,
        heads = 4,
        mlp_dim = 2048,
        image_size = 256,
        image_patch_size = 16,
        frames = 8,
        frame_patch_size = 2,
        num_classes = 1000,
        variant = 'factorized_encoder',
    )

    video = torch.randn(3, 3, 8, 256, 256)
    mask = torch.randint(0, 2, (3, 8)).bool()

    logits = vivit(video, mask = None)
    assert logits.shape == (3, 1000)
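The `mask` built above is per frame (True = keep); a short follow-up sketch actually exercising the masked temporal path (frames are grouped into frame patches, and a patch is kept only when every frame in it is unmasked):

    masked_logits = vivit(video, mask = mask)
    assert masked_logits.shape == (3, 1000)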