mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-05-15 12:54:12 +00:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
93df0e6046 | ||
|
|
8e104e9afc | ||
|
|
3f03aa3994 | ||
|
|
2da1b45b9b | ||
|
|
7ab07c2499 | ||
|
|
dea6b0da56 | ||
|
|
13284b7af1 | ||
|
|
7e18d0302e | ||
|
|
b80676e09c | ||
|
|
fc1e727428 | ||
|
|
6032a54b48 | ||
|
|
06a1f42924 | ||
|
|
6ae6a3ab64 | ||
|
|
827300beed |
3
.github/FUNDING.yml
vendored
3
.github/FUNDING.yml
vendored
@@ -1,3 +0,0 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: [lucidrains]
|
||||
36
.github/workflows/python-publish.yml
vendored
36
.github/workflows/python-publish.yml
vendored
@@ -1,36 +0,0 @@
|
||||
# This workflow will upload a Python Package using Twine when a release is created
|
||||
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
|
||||
|
||||
# This workflow uses actions that are not certified by GitHub.
|
||||
# They are provided by a third-party and are governed by
|
||||
# separate terms of service, privacy policy, and support
|
||||
# documentation.
|
||||
|
||||
name: Upload Python Package
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.x'
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install build
|
||||
- name: Build package
|
||||
run: python -m build
|
||||
- name: Publish package
|
||||
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
|
||||
with:
|
||||
user: __token__
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
34
.github/workflows/python-test.yml
vendored
34
.github/workflows/python-test.yml
vendored
@@ -1,34 +0,0 @@
|
||||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
|
||||
python -m pip install -e .
|
||||
python -m pip install pytest
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest -q
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -127,3 +127,6 @@ dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# scripts
|
||||
*.sh
|
||||
|
||||
148
README.md
148
README.md
@@ -90,26 +90,26 @@ preds = v(img) # (1, 1000)
|
||||
|
||||
## Parameters
|
||||
|
||||
- `image_size`: int.
|
||||
- `image_size`: int.
|
||||
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
|
||||
- `patch_size`: int.
|
||||
Size of patches. `image_size` must be divisible by `patch_size`.
|
||||
- `patch_size`: int.
|
||||
Size of patches. `image_size` must be divisible by `patch_size`.
|
||||
The number of patches is: ` n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
|
||||
- `num_classes`: int.
|
||||
- `num_classes`: int.
|
||||
Number of classes to classify.
|
||||
- `dim`: int.
|
||||
- `dim`: int.
|
||||
Last dimension of output tensor after linear transformation `nn.Linear(..., dim)`.
|
||||
- `depth`: int.
|
||||
- `depth`: int.
|
||||
Number of Transformer blocks.
|
||||
- `heads`: int.
|
||||
Number of heads in Multi-head Attention layer.
|
||||
- `mlp_dim`: int.
|
||||
Dimension of the MLP (FeedForward) layer.
|
||||
- `channels`: int, default `3`.
|
||||
Number of image's channels.
|
||||
- `dropout`: float between `[0, 1]`, default `0.`.
|
||||
Dropout rate.
|
||||
- `emb_dropout`: float between `[0, 1]`, default `0`.
|
||||
- `heads`: int.
|
||||
Number of heads in Multi-head Attention layer.
|
||||
- `mlp_dim`: int.
|
||||
Dimension of the MLP (FeedForward) layer.
|
||||
- `channels`: int, default `3`.
|
||||
Number of image's channels.
|
||||
- `dropout`: float between `[0, 1]`, default `0.`.
|
||||
Dropout rate.
|
||||
- `emb_dropout`: float between `[0, 1]`, default `0`.
|
||||
Embedding dropout rate.
|
||||
- `pool`: string, either `cls` token pooling or `mean` pooling
|
||||
|
||||
@@ -972,7 +972,7 @@ torch.save(model.state_dict(), './pretrained-net.pt')
|
||||
|
||||
<img src="./images/mp3.png" width="400px"></img>
|
||||
|
||||
New <a href="https://arxiv.org/abs/2207.07611">paper</a> that introduces masked position prediction pre-training criteria. This strategy is more efficient than the Masked Autoencoder strategy and has comparable performance.
|
||||
New <a href="https://arxiv.org/abs/2207.07611">paper</a> that introduces masked position prediction pre-training criteria. This strategy is more efficient than the Masked Autoencoder strategy and has comparable performance.
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -1361,7 +1361,7 @@ learner = Dino(
|
||||
num_classes_K = 65536, # output logits dimensions (referenced as K in paper)
|
||||
student_temp = 0.9, # student temperature
|
||||
teacher_temp = 0.04, # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
|
||||
local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper
|
||||
local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper
|
||||
global_lower_crop_scale = 0.5, # lower bound for global crop - 0.5 was recommended in the paper
|
||||
moving_average_decay = 0.9, # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
|
||||
center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
|
||||
@@ -1735,7 +1735,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```bibtex
|
||||
@misc{touvron2020training,
|
||||
title = {Training data-efficient image transformers & distillation through attention},
|
||||
title = {Training data-efficient image transformers & distillation through attention},
|
||||
author = {Hugo Touvron and Matthieu Cord and Matthijs Douze and Francisco Massa and Alexandre Sablayrolles and Hervé Jégou},
|
||||
year = {2020},
|
||||
eprint = {2012.12877},
|
||||
@@ -1768,7 +1768,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```bibtex
|
||||
@misc{touvron2021going,
|
||||
title = {Going deeper with Image Transformers},
|
||||
title = {Going deeper with Image Transformers},
|
||||
author = {Hugo Touvron and Matthieu Cord and Alexandre Sablayrolles and Gabriel Synnaeve and Hervé Jégou},
|
||||
year = {2021},
|
||||
eprint = {2103.17239},
|
||||
@@ -1801,7 +1801,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```bibtex
|
||||
@misc{heo2021rethinking,
|
||||
title = {Rethinking Spatial Dimensions of Vision Transformers},
|
||||
title = {Rethinking Spatial Dimensions of Vision Transformers},
|
||||
author = {Byeongho Heo and Sangdoo Yun and Dongyoon Han and Sanghyuk Chun and Junsuk Choe and Seong Joon Oh},
|
||||
year = {2021},
|
||||
eprint = {2103.16302},
|
||||
@@ -1845,7 +1845,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```bibtex
|
||||
@misc{su2021roformer,
|
||||
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
|
||||
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
|
||||
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
|
||||
year = {2021},
|
||||
eprint = {2104.09864},
|
||||
@@ -1867,7 +1867,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```bibtex
|
||||
@misc{chen2021regionvit,
|
||||
title = {RegionViT: Regional-to-Local Attention for Vision Transformers},
|
||||
title = {RegionViT: Regional-to-Local Attention for Vision Transformers},
|
||||
author = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan},
|
||||
year = {2021},
|
||||
eprint = {2106.02689},
|
||||
@@ -1878,7 +1878,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```bibtex
|
||||
@misc{wang2021crossformer,
|
||||
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
|
||||
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
|
||||
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
|
||||
year = {2021},
|
||||
eprint = {2108.00154},
|
||||
@@ -1900,7 +1900,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```bibtex
|
||||
@misc{he2021masked,
|
||||
title = {Masked Autoencoders Are Scalable Vision Learners},
|
||||
title = {Masked Autoencoders Are Scalable Vision Learners},
|
||||
author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
|
||||
year = {2021},
|
||||
eprint = {2111.06377},
|
||||
@@ -1911,7 +1911,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```bibtex
|
||||
@misc{xie2021simmim,
|
||||
title = {SimMIM: A Simple Framework for Masked Image Modeling},
|
||||
title = {SimMIM: A Simple Framework for Masked Image Modeling},
|
||||
author = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
|
||||
year = {2021},
|
||||
eprint = {2111.09886},
|
||||
@@ -1944,7 +1944,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```bibtex
|
||||
@misc{lee2021vision,
|
||||
title = {Vision Transformer for Small-Size Datasets},
|
||||
title = {Vision Transformer for Small-Size Datasets},
|
||||
author = {Seung Hoon Lee and Seunghyun Lee and Byung Cheol Song},
|
||||
year = {2021},
|
||||
eprint = {2112.13492},
|
||||
@@ -1966,7 +1966,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```bibtex
|
||||
@misc{yang2022scalablevit,
|
||||
title = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer},
|
||||
title = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer},
|
||||
author = {Rui Yang and Hailong Ma and Jie Wu and Yansong Tang and Xuefeng Xiao and Min Zheng and Xiu Li},
|
||||
year = {2022},
|
||||
eprint = {2203.10790},
|
||||
@@ -2203,37 +2203,119 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```bibtex
|
||||
@misc{carrigg2025decorrelationspeedsvisiontransformers,
|
||||
title = {Decorrelation Speeds Up Vision Transformers},
|
||||
title = {Decorrelation Speeds Up Vision Transformers},
|
||||
author = {Kieran Carrigg and Rob van Gastel and Melda Yeghaian and Sander Dalm and Faysal Boughorbel and Marcel van Gerven},
|
||||
year = {2025},
|
||||
eprint = {2510.14657},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV},
|
||||
url = {https://arxiv.org/abs/2510.14657},
|
||||
url = {https://arxiv.org/abs/2510.14657},
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{gopalakrishnan2025decouplingwhatwherepolar,
|
||||
title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
|
||||
title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
|
||||
author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
|
||||
year = {2025},
|
||||
eprint = {2509.10534},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2509.10534},
|
||||
url = {https://arxiv.org/abs/2509.10534},
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{qiu2025gatedattentionlargelanguage,
|
||||
title = {Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free},
|
||||
title = {Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free},
|
||||
author = {Zihan Qiu and Zekun Wang and Bo Zheng and Zeyu Huang and Kaiyue Wen and Songlin Yang and Rui Men and Le Yu and Fei Huang and Suozhi Huang and Dayiheng Liu and Jingren Zhou and Junyang Lin},
|
||||
year = {2025},
|
||||
eprint = {2505.06708},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CL},
|
||||
url = {https://arxiv.org/abs/2505.06708},
|
||||
url = {https://arxiv.org/abs/2505.06708}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{chen2026postlayernormbackstableexpressive,
|
||||
title = {Post-LayerNorm Is Back: Stable, ExpressivE, and Deep},
|
||||
author = {Chen Chen and Lai Wei},
|
||||
year = {2026},
|
||||
eprint = {2601.19895},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2601.19895},
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{intelligence2025pi06vlalearnsexperience,
|
||||
title = {$\pi^{*}_{0.6}$: a VLA That Learns From Experience},
|
||||
author = {Physical Intelligence and Ali Amin and Raichelle Aniceto and Ashwin Balakrishna and Kevin Black and Ken Conley and Grace Connors and James Darpinian and Karan Dhabalia and Jared DiCarlo and Danny Driess and Michael Equi and Adnan Esmail and Yunhao Fang and Chelsea Finn and Catherine Glossop and Thomas Godden and Ivan Goryachev and Lachy Groom and Hunter Hancock and Karol Hausman and Gashon Hussein and Brian Ichter and Szymon Jakubczak and Rowan Jen and Tim Jones and Ben Katz and Liyiming Ke and Chandra Kuchi and Marinda Lamb and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Yao Lu and Vishnu Mano and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Charvi Sharma and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and Will Stoeckle and Alex Swerdlow and James Tanner and Marcel Torne and Quan Vuong and Anna Walling and Haohuan Wang and Blake Williams and Sukwon Yoo and Lili Yu and Ury Zhilinsky and Zhiyuan Zhou},
|
||||
year = {2025},
|
||||
eprint = {2511.14759},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2511.14759},
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{kimiteam2026attentionresiduals,
|
||||
title = {Attention Residuals},
|
||||
author = {Kimi Team and Guangyu Chen and Yu Zhang and Jianlin Su and Weixin Xu and Siyuan Pan and Yaoyu Wang and Yucheng Wang and Guanduo Chen and Bohong Yin and Yutian Chen and Junjie Yan and Ming Wei and Y. Zhang and Fanqing Meng and Chao Hong and Xiaotong Xie and Shaowei Liu and Enzhe Lu and Yunpeng Tai and Yanru Chen and Xin Men and Haiqing Guo and Y. Charles and Haoyu Lu and Lin Sui and Jinguo Zhu and Zaida Zhou and Weiran He and Weixiao Huang and Xinran Xu and Yuzhi Wang and Guokun Lai and Yulun Du and Yuxin Wu and Zhilin Yang and Xinyu Zhou},
|
||||
year = {2026},
|
||||
eprint = {2603.15031},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CL},
|
||||
url = {https://arxiv.org/abs/2603.15031},
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{balestriero2025lejepa,
|
||||
title = {LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics},
|
||||
author = {Randall Balestriero and Yann LeCun},
|
||||
year = {2025},
|
||||
eprint = {2511.08544},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2511.08544},
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{oh2026revisitingresidualconnectionsorthogonal,
|
||||
title = {Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks},
|
||||
author = {Giyeong Oh and Woohyun Cho and Siyeol Kim and Suhwan Choi and Youngjae Yu},
|
||||
year = {2026},
|
||||
eprint = {2505.11881},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV},
|
||||
url = {https://arxiv.org/abs/2505.11881},
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{niu2026learning,
|
||||
title = {Learning to Grasp Anything By Playing with Random Toys},
|
||||
author = {Dantong Niu and Yuvan Sharma and Baifeng Shi and Rachel Ding and Matteo Gioia and Haoru Xue and Henry Tsai and Konstantinos Kallidromitis and Anirudh Pai and S. Shankar Sastry and Trevor Darrell and Jitendra Malik and Roei Herzig},
|
||||
booktitle = {The Fourteenth International Conference on Learning Representations},
|
||||
year = {2026},
|
||||
url = {https://openreview.net/forum?id=NZDaMcpXZm}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{marouani2026revisitingclspatchtoken,
|
||||
title = {Revisiting [CLS] and Patch Token Interaction in Vision Transformers},
|
||||
author = {Alexis Marouani and Oriane Siméoni and Hervé Jégou and Piotr Bojanowski and Huy V. Vo},
|
||||
year = {2026},
|
||||
eprint = {2602.08626},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV},
|
||||
url = {https://arxiv.org/abs/2602.08626},
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vit-pytorch"
|
||||
version = "1.17.6"
|
||||
version = "1.20.4"
|
||||
description = "Vision Transformer (ViT) - Pytorch"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
license = { file = "LICENSE" }
|
||||
@@ -31,8 +31,8 @@ classifiers = [
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
dependencies = [
|
||||
"einops>=0.7.0",
|
||||
"torch>=1.10",
|
||||
"einops>=0.8.2",
|
||||
"torch>=2.4",
|
||||
"torchvision",
|
||||
]
|
||||
|
||||
@@ -44,8 +44,8 @@ test = [
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/lucidrains/vit-pytorch"
|
||||
Repository = "https://github.com/lucidrains/vit-pytorch"
|
||||
Homepage = "https://codeberg.org/lucidrains/vit-pytorch"
|
||||
Repository = "https://codeberg.org/lucidrains/vit-pytorch"
|
||||
|
||||
[tool.setuptools]
|
||||
include-package-data = true
|
||||
|
||||
@@ -26,7 +26,7 @@ class AcceptVideoWrapper(Module):
|
||||
dim_emb = None,
|
||||
time_seq_len = None,
|
||||
embed_is_channel_first = False,
|
||||
output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
|
||||
output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
|
||||
proj_embed_to_dim = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -103,7 +103,7 @@ class JumboViT(Module):
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
)
|
||||
|
||||
jumbo_cls_dim = dim * jumbo_cls_k
|
||||
|
||||
|
||||
@@ -145,7 +145,7 @@ class ViT(nn.Module):
|
||||
return x
|
||||
|
||||
def forward(self, img):
|
||||
x = self.img_to_tokens(img)
|
||||
x = self.img_to_tokens(img)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
@@ -160,7 +160,7 @@ class Adapter(nn.Module):
|
||||
*,
|
||||
vit,
|
||||
num_memories_per_layer = 10,
|
||||
num_classes = 2,
|
||||
num_classes = 2,
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(vit, ViT)
|
||||
@@ -188,7 +188,7 @@ class Adapter(nn.Module):
|
||||
)
|
||||
|
||||
# specialized attention mask to preserve the output of the original ViT
|
||||
# it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa
|
||||
# it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa
|
||||
|
||||
attn_mask = torch.ones((num_patches, num_patches), dtype = torch.bool)
|
||||
attn_mask = F.pad(attn_mask, (1, num_memories_per_layer), value = False) # main tokens cannot attend to learnable memories per layer
|
||||
@@ -203,7 +203,7 @@ class Adapter(nn.Module):
|
||||
# add task specific memory tokens
|
||||
|
||||
memory_cls_tokens = repeat(self.memory_cls_token, 'd -> b 1 d', b = b)
|
||||
tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)
|
||||
tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)
|
||||
|
||||
# pass memories along with image tokens through transformer for attending
|
||||
|
||||
|
||||
319
vit_pytorch/lejepa.py
Normal file
319
vit_pytorch/lejepa.py
Normal file
@@ -0,0 +1,319 @@
|
||||
import random
|
||||
from functools import wraps
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torchvision import transforms as T
|
||||
from einops import rearrange
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def singleton(cache_key):
|
||||
def inner_fn(fn):
|
||||
@wraps(fn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
instance = getattr(self, cache_key)
|
||||
if instance is not None:
|
||||
return instance
|
||||
|
||||
instance = fn(self, *args, **kwargs)
|
||||
setattr(self, cache_key, instance)
|
||||
return instance
|
||||
return wrapper
|
||||
return inner_fn
|
||||
|
||||
def get_module_device(module):
|
||||
return next(module.parameters()).device
|
||||
|
||||
def l2norm(t, eps = 1e-6):
|
||||
return F.normalize(t, dim = -1, eps = eps)
|
||||
|
||||
# loss function
|
||||
|
||||
def sigreg_loss(
|
||||
x,
|
||||
num_slices = 1024,
|
||||
domain = (-5, 5),
|
||||
num_knots = 17
|
||||
):
|
||||
# Randall Balestriero - https://arxiv.org/abs/2511.08544
|
||||
|
||||
dim, device = x.shape[-1], x.device
|
||||
|
||||
# slice sampling
|
||||
|
||||
rand_projs = torch.randn((num_slices, dim), device = device)
|
||||
rand_projs = l2norm(rand_projs)
|
||||
|
||||
# integration points
|
||||
|
||||
t = torch.linspace(*domain, num_knots, device = device)
|
||||
|
||||
# theoretical CF for N(0, 1) and Gauss. window
|
||||
|
||||
exp_f = (-0.5 * t.square()).exp()
|
||||
|
||||
# empirical CF
|
||||
|
||||
x_t = torch.einsum('... d, m d -> ... m', x, rand_projs)
|
||||
x_t = rearrange(x_t, '... m -> (...) m')
|
||||
|
||||
x_t = rearrange(x_t, 'n m -> n m 1') * t
|
||||
ecf = (1j * x_t).exp().mean(dim = 0)
|
||||
|
||||
# weighted L2 distance
|
||||
|
||||
err = ecf.sub(exp_f).abs().square().mul(exp_f)
|
||||
|
||||
return torch.trapezoid(err, t, dim = -1).mean()
|
||||
|
||||
# augmentation utils
|
||||
|
||||
class RandomApply(Module):
|
||||
def __init__(self, fn, p):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.p = p
|
||||
|
||||
def forward(self, x):
|
||||
if random.random() > self.p:
|
||||
return x
|
||||
return self.fn(x)
|
||||
|
||||
# MLP class for projector
|
||||
|
||||
class L2Norm(Module):
|
||||
def forward(self, x, eps = 1e-6):
|
||||
return l2norm(x, eps)
|
||||
|
||||
class MLP(Module):
|
||||
def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
|
||||
super().__init__()
|
||||
|
||||
layers = []
|
||||
dims = (dim, *((hidden_size,) * (num_layers - 1)))
|
||||
|
||||
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
is_last = ind == (len(dims) - 1)
|
||||
|
||||
layers.extend([
|
||||
nn.Linear(layer_dim_in, layer_dim_out),
|
||||
nn.GELU() if not is_last else nn.Identity()
|
||||
])
|
||||
|
||||
self.net = nn.Sequential(
|
||||
*layers,
|
||||
L2Norm(),
|
||||
nn.Linear(hidden_size, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# wrapper
|
||||
|
||||
class NetWrapper(Module):
|
||||
def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
self.layer = layer
|
||||
|
||||
self.projector = None
|
||||
self.projection_hidden_size = projection_hidden_size
|
||||
self.projection_num_layers = projection_num_layers
|
||||
self.output_dim = output_dim
|
||||
|
||||
self.hidden = {}
|
||||
self.hook_registered = False
|
||||
|
||||
def _find_layer(self):
|
||||
if type(self.layer) == str:
|
||||
modules = dict([*self.net.named_modules()])
|
||||
return modules.get(self.layer, None)
|
||||
elif type(self.layer) == int:
|
||||
children = [*self.net.children()]
|
||||
return children[self.layer]
|
||||
return None
|
||||
|
||||
def _hook(self, _, input, output):
|
||||
device = input[0].device
|
||||
self.hidden[device] = output.flatten(1)
|
||||
|
||||
def _register_hook(self):
|
||||
layer = self._find_layer()
|
||||
assert layer is not None, f'hidden layer ({self.layer}) not found'
|
||||
handle = layer.register_forward_hook(self._hook)
|
||||
self.hook_registered = True
|
||||
|
||||
@singleton('projector')
|
||||
def _get_projector(self, hidden):
|
||||
_, dim = hidden.shape
|
||||
projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
|
||||
return projector.to(hidden)
|
||||
|
||||
def get_embedding(self, x):
|
||||
if self.layer == -1:
|
||||
return self.net(x)
|
||||
|
||||
if not self.hook_registered:
|
||||
self._register_hook()
|
||||
|
||||
self.hidden.clear()
|
||||
_ = self.net(x)
|
||||
hidden = self.hidden[x.device]
|
||||
self.hidden.clear()
|
||||
|
||||
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
|
||||
return hidden
|
||||
|
||||
def forward(self, x, return_projection = True):
|
||||
embed = self.get_embedding(x)
|
||||
if not return_projection:
|
||||
return embed
|
||||
|
||||
projector = self._get_projector(embed)
|
||||
return projector(embed), embed
|
||||
|
||||
# main class
|
||||
|
||||
class LeJEPA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
net,
|
||||
image_size,
|
||||
hidden_layer = -2,
|
||||
projection_hidden_size = 256,
|
||||
num_classes_K = 65336,
|
||||
projection_layers = 4,
|
||||
local_upper_crop_scale = 0.4,
|
||||
global_lower_crop_scale = 0.5,
|
||||
target_loss_weight = 1.,
|
||||
sigreg_loss_weight = 1.,
|
||||
sigreg_loss_kwargs = dict(
|
||||
num_slices = 1024,
|
||||
domain = (-5, 5),
|
||||
num_knots = 17
|
||||
),
|
||||
augment_fn = None,
|
||||
augment_fn2 = None
|
||||
):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
|
||||
# default BYOL augmentation
|
||||
|
||||
DEFAULT_AUG = torch.nn.Sequential(
|
||||
RandomApply(
|
||||
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
|
||||
p = 0.3
|
||||
),
|
||||
T.RandomGrayscale(p=0.2),
|
||||
T.RandomHorizontalFlip(),
|
||||
RandomApply(
|
||||
T.GaussianBlur((3, 3), (1.0, 2.0)),
|
||||
p = 0.2
|
||||
),
|
||||
T.Normalize(
|
||||
mean=torch.tensor([0.485, 0.456, 0.406]),
|
||||
std=torch.tensor([0.229, 0.224, 0.225])),
|
||||
)
|
||||
|
||||
self.augment1 = default(augment_fn, DEFAULT_AUG)
|
||||
self.augment2 = default(augment_fn2, DEFAULT_AUG)
|
||||
|
||||
# local and global crops
|
||||
|
||||
self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
|
||||
self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))
|
||||
|
||||
self.encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)
|
||||
|
||||
self.target_loss_weight = target_loss_weight
|
||||
self.sigreg_loss_weight = sigreg_loss_weight
|
||||
self.sigreg_loss_kwargs = sigreg_loss_kwargs
|
||||
|
||||
# get device of network and make wrapper same device
|
||||
device = get_module_device(net)
|
||||
self.to(device)
|
||||
|
||||
# send a mock image tensor to instantiate singleton parameters
|
||||
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_embedding = False,
|
||||
return_projection = True
|
||||
):
|
||||
if return_embedding:
|
||||
return self.encoder(x, return_projection = return_projection)
|
||||
|
||||
image_one, image_two = self.augment1(x), self.augment2(x)
|
||||
|
||||
local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
|
||||
global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
|
||||
|
||||
local_images = torch.cat((local_image_one, local_image_two), dim = 0)
|
||||
proj_locals, _ = self.encoder(local_images)
|
||||
proj_local_one, proj_local_two = proj_locals.chunk(2, dim = 0)
|
||||
|
||||
with torch.no_grad():
|
||||
global_images = torch.cat((global_image_one, global_image_two), dim = 0)
|
||||
proj_globals, _ = self.encoder(global_images)
|
||||
proj_global_one, proj_global_two = proj_globals.chunk(2, dim = 0)
|
||||
|
||||
# invariance loss
|
||||
|
||||
mse_loss = F.mse_loss(proj_local_one, proj_global_two) + F.mse_loss(proj_local_two, proj_global_one)
|
||||
|
||||
# sigreg loss
|
||||
|
||||
sreg_loss = sigreg_loss(proj_locals, **self.sigreg_loss_kwargs)
|
||||
|
||||
return mse_loss * self.target_loss_weight + sreg_loss * self.sigreg_loss_weight
|
||||
|
||||
# quick run
|
||||
|
||||
if __name__ == '__main__':
|
||||
from vit_pytorch import ViT
|
||||
|
||||
model = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
learner = LeJEPA(
|
||||
model,
|
||||
image_size = 256,
|
||||
hidden_layer = 'to_latent', # layer name where output is hidden dimension
|
||||
projection_hidden_size = 256, # projector network hidden dimension
|
||||
projection_layers = 4, # number of layers in projection network
|
||||
num_classes_K = 65336, # output dimension
|
||||
target_loss_weight = 1.0,
|
||||
sigreg_loss_weight = 1.0
|
||||
)
|
||||
|
||||
opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)
|
||||
|
||||
images = torch.randn(8, 3, 256, 256)
|
||||
|
||||
loss = learner(images)
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
opt.step()
|
||||
|
||||
print('loss:', loss.item())
|
||||
@@ -182,7 +182,7 @@ class LeViT(nn.Module):
|
||||
def forward(self, img):
|
||||
x = self.conv_embedding(img)
|
||||
|
||||
x = self.backbone(x)
|
||||
x = self.backbone(x)
|
||||
|
||||
x = self.pool(x)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ class MAE(nn.Module):
|
||||
if self.encoder.pool == "cls":
|
||||
tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
|
||||
elif self.encoder.pool == "mean":
|
||||
tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)
|
||||
tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)
|
||||
|
||||
# calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked
|
||||
|
||||
@@ -87,7 +87,7 @@ class MAE(nn.Module):
|
||||
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
|
||||
|
||||
# concat the masked tokens to the decoder tokens and attend with decoder
|
||||
|
||||
|
||||
decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
|
||||
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
|
||||
decoder_tokens[batch_range, masked_indices] = mask_tokens
|
||||
|
||||
@@ -77,7 +77,7 @@ class Dropsample(nn.Module):
|
||||
def __init__(self, prob = 0):
|
||||
super().__init__()
|
||||
self.prob = prob
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ class Dropsample(Module):
|
||||
def __init__(self, prob = 0):
|
||||
super().__init__()
|
||||
self.prob = prob
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
|
||||
|
||||
@@ -1,243 +1,243 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def conv_1x1_bn(inp, oup):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim=-1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b p h n d -> b p n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
"""Transformer block described in ViT.
|
||||
Paper: https://arxiv.org/abs/2010.11929
|
||||
Based on: https://github.com/lucidrains/vit-pytorch
|
||||
"""
|
||||
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads, dim_head, dropout),
|
||||
FeedForward(dim, mlp_dim, dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class MV2Block(nn.Module):
|
||||
"""MV2 block described in MobileNetV2.
|
||||
Paper: https://arxiv.org/pdf/1801.04381
|
||||
Based on: https://github.com/tonylins/pytorch-mobilenet-v2
|
||||
"""
|
||||
|
||||
def __init__(self, inp, oup, stride=1, expansion=4):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(inp * expansion)
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
if expansion == 1:
|
||||
self.conv = nn.Sequential(
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
|
||||
1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Sequential(
|
||||
# pw
|
||||
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
|
||||
1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(x)
|
||||
if self.use_res_connect:
|
||||
out = out + x
|
||||
return out
|
||||
|
||||
class MobileViTBlock(nn.Module):
|
||||
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.ph, self.pw = patch_size
|
||||
|
||||
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
|
||||
self.conv2 = conv_1x1_bn(channel, dim)
|
||||
|
||||
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
|
||||
|
||||
self.conv3 = conv_1x1_bn(dim, channel)
|
||||
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
y = x.clone()
|
||||
|
||||
# Local representations
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
|
||||
# Global representations
|
||||
_, _, h, w = x.shape
|
||||
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
|
||||
x = self.transformer(x)
|
||||
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
|
||||
|
||||
# Fusion
|
||||
x = self.conv3(x)
|
||||
x = torch.cat((x, y), 1)
|
||||
x = self.conv4(x)
|
||||
return x
|
||||
|
||||
class MobileViT(nn.Module):
|
||||
"""MobileViT.
|
||||
Paper: https://arxiv.org/abs/2110.02178
|
||||
Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
dims,
|
||||
channels,
|
||||
num_classes,
|
||||
expansion=4,
|
||||
kernel_size=3,
|
||||
patch_size=(2, 2),
|
||||
depths=(2, 4, 3)
|
||||
):
|
||||
super().__init__()
|
||||
assert len(dims) == 3, 'dims must be a tuple of 3'
|
||||
assert len(depths) == 3, 'depths must be a tuple of 3'
|
||||
|
||||
ih, iw = image_size
|
||||
ph, pw = patch_size
|
||||
assert ih % ph == 0 and iw % pw == 0
|
||||
|
||||
init_dim, *_, last_dim = channels
|
||||
|
||||
self.conv1 = conv_nxn_bn(3, init_dim, stride=2)
|
||||
|
||||
self.stem = nn.ModuleList([])
|
||||
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
|
||||
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
|
||||
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
|
||||
self.trunk = nn.ModuleList([])
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[3], channels[4], 2, expansion),
|
||||
MobileViTBlock(dims[0], depths[0], channels[5],
|
||||
kernel_size, patch_size, int(dims[0] * 2))
|
||||
]))
|
||||
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[5], channels[6], 2, expansion),
|
||||
MobileViTBlock(dims[1], depths[1], channels[7],
|
||||
kernel_size, patch_size, int(dims[1] * 4))
|
||||
]))
|
||||
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[7], channels[8], 2, expansion),
|
||||
MobileViTBlock(dims[2], depths[2], channels[9],
|
||||
kernel_size, patch_size, int(dims[2] * 4))
|
||||
]))
|
||||
|
||||
self.to_logits = nn.Sequential(
|
||||
conv_1x1_bn(channels[-2], last_dim),
|
||||
Reduce('b c h w -> b c', 'mean'),
|
||||
nn.Linear(channels[-1], num_classes, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
|
||||
for conv in self.stem:
|
||||
x = conv(x)
|
||||
|
||||
for conv, attn in self.trunk:
|
||||
x = conv(x)
|
||||
x = attn(x)
|
||||
|
||||
return self.to_logits(x)
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def conv_1x1_bn(inp, oup):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim=-1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b p h n d -> b p n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
"""Transformer block described in ViT.
|
||||
Paper: https://arxiv.org/abs/2010.11929
|
||||
Based on: https://github.com/lucidrains/vit-pytorch
|
||||
"""
|
||||
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads, dim_head, dropout),
|
||||
FeedForward(dim, mlp_dim, dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class MV2Block(nn.Module):
|
||||
"""MV2 block described in MobileNetV2.
|
||||
Paper: https://arxiv.org/pdf/1801.04381
|
||||
Based on: https://github.com/tonylins/pytorch-mobilenet-v2
|
||||
"""
|
||||
|
||||
def __init__(self, inp, oup, stride=1, expansion=4):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(inp * expansion)
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
if expansion == 1:
|
||||
self.conv = nn.Sequential(
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
|
||||
1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Sequential(
|
||||
# pw
|
||||
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
|
||||
1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(x)
|
||||
if self.use_res_connect:
|
||||
out = out + x
|
||||
return out
|
||||
|
||||
class MobileViTBlock(nn.Module):
|
||||
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.ph, self.pw = patch_size
|
||||
|
||||
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
|
||||
self.conv2 = conv_1x1_bn(channel, dim)
|
||||
|
||||
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
|
||||
|
||||
self.conv3 = conv_1x1_bn(dim, channel)
|
||||
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
y = x.clone()
|
||||
|
||||
# Local representations
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
|
||||
# Global representations
|
||||
_, _, h, w = x.shape
|
||||
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
|
||||
x = self.transformer(x)
|
||||
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
|
||||
|
||||
# Fusion
|
||||
x = self.conv3(x)
|
||||
x = torch.cat((x, y), 1)
|
||||
x = self.conv4(x)
|
||||
return x
|
||||
|
||||
class MobileViT(nn.Module):
|
||||
"""MobileViT.
|
||||
Paper: https://arxiv.org/abs/2110.02178
|
||||
Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
dims,
|
||||
channels,
|
||||
num_classes,
|
||||
expansion=4,
|
||||
kernel_size=3,
|
||||
patch_size=(2, 2),
|
||||
depths=(2, 4, 3)
|
||||
):
|
||||
super().__init__()
|
||||
assert len(dims) == 3, 'dims must be a tuple of 3'
|
||||
assert len(depths) == 3, 'depths must be a tuple of 3'
|
||||
|
||||
ih, iw = image_size
|
||||
ph, pw = patch_size
|
||||
assert ih % ph == 0 and iw % pw == 0
|
||||
|
||||
init_dim, *_, last_dim = channels
|
||||
|
||||
self.conv1 = conv_nxn_bn(3, init_dim, stride=2)
|
||||
|
||||
self.stem = nn.ModuleList([])
|
||||
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
|
||||
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
|
||||
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
|
||||
self.trunk = nn.ModuleList([])
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[3], channels[4], 2, expansion),
|
||||
MobileViTBlock(dims[0], depths[0], channels[5],
|
||||
kernel_size, patch_size, int(dims[0] * 2))
|
||||
]))
|
||||
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[5], channels[6], 2, expansion),
|
||||
MobileViTBlock(dims[1], depths[1], channels[7],
|
||||
kernel_size, patch_size, int(dims[1] * 4))
|
||||
]))
|
||||
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[7], channels[8], 2, expansion),
|
||||
MobileViTBlock(dims[2], depths[2], channels[9],
|
||||
kernel_size, patch_size, int(dims[2] * 4))
|
||||
]))
|
||||
|
||||
self.to_logits = nn.Sequential(
|
||||
conv_1x1_bn(channels[-2], last_dim),
|
||||
Reduce('b c h w -> b c', 'mean'),
|
||||
nn.Linear(channels[-1], num_classes, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
|
||||
for conv in self.stem:
|
||||
x = conv(x)
|
||||
|
||||
for conv, attn in self.trunk:
|
||||
x = conv(x)
|
||||
x = attn(x)
|
||||
|
||||
return self.to_logits(x)
|
||||
|
||||
@@ -110,7 +110,7 @@ class ViT(nn.Module):
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
@@ -178,7 +178,7 @@ class MP3(nn.Module):
|
||||
|
||||
attended_tokens = self.vit.transformer(tokens, tokens_unmasked)
|
||||
logits = rearrange(self.mlp_head(attended_tokens), 'b n d -> (b n) d')
|
||||
|
||||
|
||||
# Define labels
|
||||
labels = repeat(torch.arange(num_patches, device = device), 'n -> (b n)', b = batch)
|
||||
loss = F.cross_entropy(logits, labels)
|
||||
|
||||
@@ -343,11 +343,11 @@ class NaViT(nn.Module):
|
||||
|
||||
# need to know how many images for final attention pooling
|
||||
|
||||
num_images = torch.tensor(num_images, device = device, dtype = torch.long)
|
||||
num_images = torch.tensor(num_images, device = device, dtype = torch.long)
|
||||
|
||||
# to patches
|
||||
|
||||
x = self.to_patch_embedding(patches)
|
||||
x = self.to_patch_embedding(patches)
|
||||
|
||||
# factorized 2d absolute positional embedding
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ class Attention(Module):
|
||||
self.to_out = nn.Linear(dim_inner, dim, bias = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
self,
|
||||
x,
|
||||
context: Tensor | None = None
|
||||
):
|
||||
|
||||
@@ -64,7 +64,7 @@ class Attention(Module):
|
||||
self.to_out = nn.Linear(dim_inner, dim, bias = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
self,
|
||||
x,
|
||||
context: Tensor | None = None
|
||||
):
|
||||
|
||||
@@ -154,7 +154,7 @@ class PiT(nn.Module):
|
||||
|
||||
for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
|
||||
not_last = ind < (len(depth) - 1)
|
||||
|
||||
|
||||
layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))
|
||||
|
||||
if not_last:
|
||||
|
||||
@@ -146,7 +146,7 @@ class R2LTransformer(nn.Module):
|
||||
region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')
|
||||
|
||||
# calculate local relative positional bias
|
||||
|
||||
|
||||
h_range = torch.arange(window_size_h, device = device)
|
||||
w_range = torch.arange(window_size_w, device = device)
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
|
||||
|
||||
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w)
|
||||
|
||||
# add LIM output
|
||||
# add LIM output
|
||||
|
||||
out = out + local_out
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
return pe.type(dtype)
|
||||
|
||||
@@ -57,7 +57,7 @@ class Attend(nn.Module):
|
||||
config = self.cuda_config if q.is_cuda else self.cpu_config
|
||||
|
||||
# flash attention - https://arxiv.org/abs/2205.14135
|
||||
|
||||
|
||||
with torch.backends.cuda.sdp_kernel(**config._asdict()):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
|
||||
|
||||
z = z.flatten()[:, None] * omega[None, :]
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
|
||||
|
||||
@@ -52,7 +52,7 @@ class Attend(Module):
|
||||
|
||||
def flash_attn(self, q, k, v):
|
||||
# flash attention - https://arxiv.org/abs/2205.14135
|
||||
|
||||
|
||||
with torch.backends.cuda.sdp_kernel(**self.config._asdict()):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ def FeedForward(dim, hidden_dim):
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
|
||||
@@ -98,7 +98,7 @@ class SimpleViT(nn.Module):
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
|
||||
|
||||
z = z.flatten()[:, None] * omega[None, :]
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
|
||||
|
||||
|
||||
241
vit_pytorch/simple_vit_attn_residual.py
Normal file
241
vit_pytorch/simple_vit_attn_residual.py
Normal file
@@ -0,0 +1,241 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def last(arr):
|
||||
return arr[-1]
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing = 'ij')
|
||||
assert divisible_by(dim, 4), 'feature dimension must be multiple of 4 for sincos emb'
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, cross_attend = False):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.norm_context = nn.LayerNorm(dim) if cross_attend else None
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x, context = None):
|
||||
x = self.norm(x)
|
||||
|
||||
if exists(context):
|
||||
context = self.norm_context(context)
|
||||
else:
|
||||
context = x
|
||||
|
||||
q = self.to_q(x)
|
||||
k, v = self.to_kv(context).chunk(2, dim = -1)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
attn = dots.softmax(dim = -1)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class AttentionResidual(Module):
|
||||
def __init__(self, fn, dim, heads = 8, dim_head = 64, learned_query = True, disable = False):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.disable = disable
|
||||
|
||||
if disable:
|
||||
return
|
||||
|
||||
self.attn = Attention(dim, heads = heads, dim_head = dim_head, cross_attend = True)
|
||||
self.learned_query = nn.Parameter(torch.randn(dim)) if learned_query else None
|
||||
|
||||
def forward(self, history: list[Tensor]) -> Tensor:
|
||||
if self.disable:
|
||||
return self.fn(last(history))
|
||||
|
||||
batch, seq_len = history[0].shape[:2]
|
||||
|
||||
context = torch.stack(history, dim = 2)
|
||||
context = rearrange(context, 'b n l d -> (b n) l d')
|
||||
|
||||
if exists(self.learned_query):
|
||||
q = repeat(self.learned_query, 'd -> (b n) 1 d', b = batch, n = seq_len)
|
||||
else:
|
||||
q = rearrange(last(history), 'b n d -> (b n) 1 d')
|
||||
|
||||
pooled = self.attn(q, context = context)
|
||||
pooled = rearrange(pooled, '(b n) 1 d -> b n d', b = batch, n = seq_len)
|
||||
|
||||
return self.fn(pooled)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, learned_query = True):
|
||||
super().__init__()
|
||||
|
||||
self.layers = ModuleList([])
|
||||
for ind in range(depth):
|
||||
is_first = ind == 0
|
||||
|
||||
self.layers.append(ModuleList([
|
||||
AttentionResidual(Attention(dim, heads = heads, dim_head = dim_head), dim, heads = heads, dim_head = dim_head, learned_query = learned_query, disable = is_first),
|
||||
AttentionResidual(FeedForward(dim, mlp_dim), dim, heads = heads, dim_head = dim_head, learned_query = learned_query),
|
||||
]))
|
||||
|
||||
self.final_pool = AttentionResidual(nn.LayerNorm(dim), dim, heads = heads, dim_head = dim_head, learned_query = learned_query)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
history: list[Tensor] | None = None,
|
||||
return_history = False
|
||||
):
|
||||
history = [*default(history, [])]
|
||||
|
||||
history.append(x)
|
||||
|
||||
for attn_residual, ff_residual in self.layers:
|
||||
history.append(attn_residual(history))
|
||||
history.append(ff_residual(history))
|
||||
|
||||
out = self.final_pool(history)
|
||||
|
||||
if return_history:
|
||||
return out, history
|
||||
|
||||
return out
|
||||
|
||||
class SimpleViTAttnResidual(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
learned_query = True
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, learned_query = learned_query)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img,
|
||||
history: list[Tensor] | None = None,
|
||||
return_history = False
|
||||
):
|
||||
device, dtype = img.device, img.dtype
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
x += self.pos_embedding.to(device, dtype = dtype)
|
||||
|
||||
x = self.transformer(x, history = history, return_history = return_history)
|
||||
|
||||
if return_history:
|
||||
x, history = x
|
||||
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
out = self.linear_head(x)
|
||||
|
||||
if return_history:
|
||||
return out, history
|
||||
|
||||
return out
|
||||
|
||||
if __name__ == '__main__':
|
||||
for learned_query in (True, False):
|
||||
v = SimpleViTAttnResidual(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
learned_query = learned_query
|
||||
)
|
||||
|
||||
img = torch.randn(2, 3, 256, 256)
|
||||
preds, history = v(img, return_history = True)
|
||||
|
||||
assert preds.shape == (2, 1000)
|
||||
|
||||
preds, _ = v(img, history = history, return_history = True)
|
||||
|
||||
assert preds.shape == (2, 1000)
|
||||
206
vit_pytorch/simple_vit_orthog_residual_update.py
Normal file
206
vit_pytorch/simple_vit_orthog_residual_update.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks
|
||||
# Giyeong Oh et al. https://arxiv.org/abs/2505.11881
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class OrthogonalResidualUpdate(Module):
|
||||
def __init__(
|
||||
self,
|
||||
block: Module,
|
||||
dim = None,
|
||||
double_precision = True,
|
||||
learned = False
|
||||
):
|
||||
super().__init__()
|
||||
self.block = block
|
||||
self.double_precision = double_precision
|
||||
|
||||
self.learned = learned
|
||||
|
||||
if learned:
|
||||
assert exists(dim)
|
||||
self.to_modulation = nn.Linear(dim, 2)
|
||||
|
||||
def orthog_proj(self, block_out, residual):
|
||||
use_double, dtype = self.double_precision, residual.dtype
|
||||
|
||||
if use_double:
|
||||
residual, block_out = residual.double(), block_out.double()
|
||||
|
||||
# get orthogonal projection of the attention or feedforward output respect to residual
|
||||
|
||||
unit = F.normalize(residual, dim = -1)
|
||||
parallel = (block_out * unit).sum(dim = -1, keepdim = True) * unit
|
||||
orthogonal = block_out - parallel
|
||||
|
||||
# back to original dtype if double precision
|
||||
|
||||
if use_double:
|
||||
parallel, orthogonal = parallel.to(dtype), orthogonal.to(dtype)
|
||||
|
||||
return parallel, orthogonal
|
||||
|
||||
def forward(self, residual):
|
||||
block_out = self.block(residual)
|
||||
|
||||
parallel_update, orthog_update = self.orthog_proj(block_out, residual)
|
||||
|
||||
if self.learned:
|
||||
parallel_mod, orthog_mod = self.to_modulation(block_out).sigmoid().split(1, dim = -1)
|
||||
parallel_update = parallel_update * parallel_mod
|
||||
orthog_update = orthog_update * orthog_mod
|
||||
else:
|
||||
parallel_update = 0
|
||||
|
||||
return residual + parallel_update + orthog_update
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, orthog_residual_update_kwargs: dict = dict()):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
attn = Attention(dim, heads = heads, dim_head = dim_head)
|
||||
ff = FeedForward(dim, mlp_dim)
|
||||
|
||||
self.layers.append(ModuleList([
|
||||
OrthogonalResidualUpdate(attn, dim = dim, **orthog_residual_update_kwargs),
|
||||
OrthogonalResidualUpdate(ff, dim = dim, **orthog_residual_update_kwargs)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x)
|
||||
x = ff(x)
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, orthog_residual_update_kwargs: dict = dict()):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, orthog_residual_update_kwargs)
|
||||
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
device = img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
x += self.pos_embedding.to(device, dtype=x.dtype)
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
vit = SimpleViT(
|
||||
image_size = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 10,
|
||||
dim = 512,
|
||||
depth = 2,
|
||||
heads = 4,
|
||||
mlp_dim = 2048,
|
||||
orthog_residual_update_kwargs = dict(
|
||||
learned = True
|
||||
)
|
||||
)
|
||||
|
||||
images = torch.randn(2, 3, 256, 256)
|
||||
|
||||
assert vit(images).shape == (2, 10)
|
||||
@@ -186,7 +186,7 @@ class SimpleViT(nn.Module):
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
return pe.type(dtype)
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ class SimpleViT(nn.Module):
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
|
||||
@@ -105,7 +105,7 @@ class SimpleViT(nn.Module):
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
|
||||
205
vit_pytorch/simple_vit_with_specialized_cls.py
Normal file
205
vit_pytorch/simple_vit_with_specialized_cls.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# Alexis Marouani et al. https://arxiv.org/abs/2602.08626
|
||||
|
||||
import torch
|
||||
from torch import nn, cat, Tensor, is_tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class Specialized(Module):
|
||||
def __init__(self, modules: list[Module]):
|
||||
super().__init__()
|
||||
self.fns = ModuleList(modules)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor | list[Tensor],
|
||||
token_lens: tuple[int, ...] = None
|
||||
):
|
||||
if is_tensor(x):
|
||||
assert exists(token_lens)
|
||||
x = x.split(token_lens, dim = 1)
|
||||
|
||||
assert len(self.fns) == len(x)
|
||||
|
||||
out = tuple(fn(t) for fn, t in zip(self.fns, x))
|
||||
|
||||
if is_tensor:
|
||||
out = cat(out, dim = 1)
|
||||
|
||||
return out
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.norm = Specialized([
|
||||
nn.LayerNorm(dim),
|
||||
nn.LayerNorm(dim)
|
||||
])
|
||||
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x, token_lens = None):
|
||||
x = self.norm(x, token_lens = token_lens)
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, specialize_qkv = False):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = Specialized([
|
||||
nn.LayerNorm(dim),
|
||||
nn.LayerNorm(dim)
|
||||
])
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.specialize_qkv = specialize_qkv
|
||||
|
||||
if specialize_qkv:
|
||||
self.to_qkv = Specialized([
|
||||
nn.Linear(dim, inner_dim * 3, bias = False),
|
||||
nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
])
|
||||
else:
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x, token_lens = None):
|
||||
x = self.norm(x, token_lens = token_lens)
|
||||
|
||||
if self.specialize_qkv:
|
||||
qkv = self.to_qkv(x, token_lens = token_lens).chunk(3, dim = -1)
|
||||
else:
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
|
||||
q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, specialize_qkv_depth):
|
||||
super().__init__()
|
||||
self.norm = Specialized([nn.LayerNorm(dim), nn.LayerNorm(dim)])
|
||||
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for ind in range(depth):
|
||||
specialize_qkv = ind < specialize_qkv_depth
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, specialize_qkv = specialize_qkv),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
|
||||
def forward(self, x, token_lens = None):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, token_lens = token_lens) + x
|
||||
x = ff(x, token_lens = token_lens) + x
|
||||
|
||||
return self.norm(x, token_lens = token_lens)
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, specialize_qkv_depth = None):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.randn(dim) * 1e-2)
|
||||
|
||||
specialize_qkv_depth = default(specialize_qkv_depth, depth // 3) # author found just first third of transformer having specialized qkv projection for cls token is enough
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, specialize_qkv_depth)
|
||||
|
||||
self.pool = 'cls'
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
device = img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
x += self.pos_embedding.to(device, dtype=x.dtype)
|
||||
|
||||
b, n, _ = x.shape
|
||||
cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)
|
||||
x = cat((cls_tokens, x), dim = 1)
|
||||
|
||||
x = self.transformer(x, token_lens = (1, n))
|
||||
|
||||
x = x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
|
||||
if __name__ == '__main__':
|
||||
v = SimpleViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
out = v(img)
|
||||
|
||||
assert out.shape == (1, 1000)
|
||||
@@ -120,7 +120,7 @@ class SimpleViT(Module):
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ def posemb_sincos_2d(
|
||||
omega = temperature ** -omega
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
|
||||
pe = cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
pe = pe.type(dtype)
|
||||
@@ -442,7 +442,8 @@ class VAAT(Module):
|
||||
self_attn_heads = 4,
|
||||
self_attn_dim_head = 32,
|
||||
ast_layer_indices: tuple[int, ...] | None = None,
|
||||
vit_layer_indices: tuple[int, ...] | None = None
|
||||
vit_layer_indices: tuple[int, ...] | None = None,
|
||||
num_advantage_bins = 0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -480,7 +481,7 @@ class VAAT(Module):
|
||||
|
||||
assert len(ast_layer_indices) == depth, f'number of ast layer indices {len(ast_layer_indices)} does not much the VAAT depth {depth}'
|
||||
|
||||
self.register_buffer('ast_layer_indices', tensor(vit_layer_indices), persistent = False)
|
||||
self.register_buffer('ast_layer_indices', tensor(ast_layer_indices), persistent = False)
|
||||
|
||||
# handle maybe multiple frames
|
||||
|
||||
@@ -511,6 +512,14 @@ class VAAT(Module):
|
||||
|
||||
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
|
||||
|
||||
# handle maybe advantage conditioning
|
||||
|
||||
self.has_advantages = num_advantage_bins > 0
|
||||
self.num_advantage_bins = num_advantage_bins
|
||||
|
||||
if self.has_advantages:
|
||||
self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)
|
||||
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
@@ -540,14 +549,15 @@ class VAAT(Module):
|
||||
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
|
||||
audio_or_spec, # (b v? t) | (b v?f t) - batch, audio len | batch, spec freq, time
|
||||
*,
|
||||
extra = None, # (b d) - batch, dim extra
|
||||
extra = None, # (b d) - batch, dim extra
|
||||
tasks = None, # (b)
|
||||
advantages = None,# (b)
|
||||
actions = None, # (b k d) - batch, action chunk length, action dimension
|
||||
return_hiddens = False,
|
||||
freeze_vit = False,
|
||||
freeze_ast = False
|
||||
):
|
||||
batch = video_or_image.shape[0]
|
||||
batch, device = video_or_image.shape[0], video_or_image.device
|
||||
return_loss = exists(actions)
|
||||
|
||||
# handle some various input dimensions
|
||||
@@ -655,53 +665,66 @@ class VAAT(Module):
|
||||
|
||||
audio_context = rearrange(audio_hiddens, 'l b v n d -> l b (v n) d')
|
||||
|
||||
# get main action tokens and maybe append extra
|
||||
# main action tokens
|
||||
|
||||
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
|
||||
action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)
|
||||
|
||||
has_extra = exists(extra)
|
||||
# maybe advantage tokens
|
||||
|
||||
if has_extra:
|
||||
assert self.accept_extra_token
|
||||
empty_token = action_tokens[:, 0:0]
|
||||
|
||||
extra_token = self.to_extra_token(extra)
|
||||
maybe_advantage_embed = empty_token
|
||||
|
||||
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
|
||||
if self.has_advantages and exists(advantages):
|
||||
if isinstance(advantages, int):
|
||||
advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)
|
||||
|
||||
maybe_advantage_embed = self.advantage_emb(advantages + 1)
|
||||
|
||||
# register tokens
|
||||
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
register_tokens = empty_token
|
||||
|
||||
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
|
||||
if exists(self.register_tokens):
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
|
||||
# cross attention
|
||||
# extra
|
||||
|
||||
hiddens = [action_tokens]
|
||||
maybe_extra_embed = empty_token
|
||||
|
||||
has_extra = exists(extra)
|
||||
if has_extra:
|
||||
assert self.accept_extra_token
|
||||
|
||||
maybe_extra_embed = self.to_extra_token(extra)
|
||||
|
||||
# pack all tokens for attention
|
||||
|
||||
tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')
|
||||
|
||||
# transformer
|
||||
|
||||
hiddens = [tokens]
|
||||
|
||||
for (maybe_film, maybe_self_attn, image_cross_attn, audio_cross_attn, ff), image_layer_context, audio_layer_context in zip(self.layers, image_context, audio_context):
|
||||
|
||||
if exists(tasks):
|
||||
action_tokens = maybe_film(action_tokens, task_emb)
|
||||
if exists(maybe_film) and exists(tasks):
|
||||
tokens = maybe_film(tokens, task_emb)
|
||||
|
||||
action_tokens = image_cross_attn(action_tokens, image_layer_context) + action_tokens
|
||||
tokens = image_cross_attn(tokens, image_layer_context) + tokens
|
||||
|
||||
action_tokens = audio_cross_attn(action_tokens, audio_layer_context) + action_tokens
|
||||
tokens = audio_cross_attn(tokens, audio_layer_context) + tokens
|
||||
|
||||
if exists(maybe_self_attn):
|
||||
action_tokens = maybe_self_attn(action_tokens) + action_tokens
|
||||
tokens = maybe_self_attn(tokens) + tokens
|
||||
|
||||
action_tokens = ff(action_tokens) + action_tokens
|
||||
tokens = ff(tokens) + tokens
|
||||
|
||||
hiddens.append(action_tokens)
|
||||
hiddens.append(tokens)
|
||||
|
||||
# unpack registers
|
||||
# unpack register, advantage, action, and extra tokens
|
||||
|
||||
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
|
||||
|
||||
# maybe unpack extra
|
||||
|
||||
if has_extra:
|
||||
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
|
||||
maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')
|
||||
|
||||
# norm and prediction
|
||||
|
||||
@@ -744,43 +767,51 @@ if __name__ == '__main__':
|
||||
mlp_dim = 384 * 4
|
||||
)
|
||||
|
||||
vaat = VAAT(
|
||||
vit,
|
||||
ast,
|
||||
dim = 512,
|
||||
depth = 9,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
mlp_dim = 2048,
|
||||
dim_action = 20,
|
||||
action_chunk_len = 7,
|
||||
time_seq_len = 4,
|
||||
num_image_views = 2,
|
||||
num_audio_views = 2,
|
||||
num_tasks = 4,
|
||||
add_self_attn = True,
|
||||
dim_extra_token = 33, # extra token with some variable dimension
|
||||
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
|
||||
0, 0, 1, 1, 2, 2, 3, 3, 4
|
||||
),
|
||||
ast_layer_indices = (
|
||||
1, 1, 1, 2, 2, 2, 3, 3, 3
|
||||
for num_adv_bins in (0, 2, 10):
|
||||
vaat = VAAT(
|
||||
vit,
|
||||
ast,
|
||||
dim = 512,
|
||||
depth = 9,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
mlp_dim = 2048,
|
||||
dim_action = 20,
|
||||
action_chunk_len = 7,
|
||||
time_seq_len = 4,
|
||||
num_image_views = 2,
|
||||
num_audio_views = 2,
|
||||
num_tasks = 4,
|
||||
num_advantage_bins = num_adv_bins,
|
||||
add_self_attn = True,
|
||||
dim_extra_token = 33, # extra token with some variable dimension
|
||||
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
|
||||
0, 0, 1, 1, 2, 2, 3, 3, 4
|
||||
),
|
||||
ast_layer_indices = (
|
||||
1, 1, 1, 2, 2, 2, 3, 3, 3
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
|
||||
audio = torch.randn(2, 2, 14_100 * 5)
|
||||
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
|
||||
audio = torch.randn(2, 2, 14_100 * 5)
|
||||
|
||||
tasks = torch.randint(0, 4, (2,))
|
||||
extra = torch.randn(2, 33) # extra internal state
|
||||
tasks = torch.randint(0, 4, (2,))
|
||||
extra = torch.randn(2, 33) # extra internal state
|
||||
|
||||
actions = torch.randn(2, 7, 20) # actions for learning
|
||||
# advantage conditioning
|
||||
|
||||
loss = vaat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss.backward()
|
||||
advantages = None
|
||||
if num_adv_bins > 0:
|
||||
advantages = torch.randint(-1, num_adv_bins, (2,))
|
||||
|
||||
# after much training
|
||||
actions = torch.randn(2, 7, 20) # actions for learning
|
||||
|
||||
pred_actions, hiddens = vaat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
|
||||
loss = vaat(images, audio, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss.backward()
|
||||
|
||||
assert pred_actions.shape == (2, 7, 20)
|
||||
# after much training
|
||||
|
||||
pred_actions, hiddens = vaat(images, audio, advantages = advantages, tasks = tasks, extra = extra, return_hiddens = True)
|
||||
|
||||
assert pred_actions.shape == (2, 7, 20)
|
||||
|
||||
@@ -278,7 +278,8 @@ class VAT(Module):
|
||||
add_self_attn = True, # in the paper, they didn't have any ways for the action token to exchange information with the extra token, so we'll just add it as an option
|
||||
self_attn_heads = 4,
|
||||
self_attn_dim_head = 32,
|
||||
vit_layer_indices: tuple[int, ...] | None = None
|
||||
vit_layer_indices: tuple[int, ...] | None = None,
|
||||
num_advantage_bins = 0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -324,6 +325,14 @@ class VAT(Module):
|
||||
|
||||
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
|
||||
|
||||
# handle maybe advantage conditioning
|
||||
|
||||
self.has_advantages = num_advantage_bins > 0
|
||||
self.num_advantage_bins = num_advantage_bins
|
||||
|
||||
if self.has_advantages:
|
||||
self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)
|
||||
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
@@ -351,13 +360,14 @@ class VAT(Module):
|
||||
self,
|
||||
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
|
||||
*,
|
||||
extra = None, # (b d) - batch, dim extra
|
||||
extra = None, # (b d) - batch, dim extra
|
||||
tasks = None, # (b)
|
||||
advantages = None,# (b)
|
||||
actions = None, # (b k d) - batch, action chunk length, action dimension
|
||||
return_hiddens = False,
|
||||
freeze_vit = False
|
||||
):
|
||||
batch = video_or_image.shape[0]
|
||||
batch, device = video_or_image.shape[0], video_or_image.device
|
||||
return_loss = exists(actions)
|
||||
|
||||
# handle some various input dimensions
|
||||
@@ -423,51 +433,64 @@ class VAT(Module):
|
||||
|
||||
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
|
||||
|
||||
# get main action tokens and maybe append extra
|
||||
# main action tokens
|
||||
|
||||
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
|
||||
action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)
|
||||
|
||||
has_extra = exists(extra)
|
||||
# maybe advantage tokens
|
||||
|
||||
if has_extra:
|
||||
assert self.accept_extra_token
|
||||
empty_token = action_tokens[:, 0:0]
|
||||
|
||||
extra_token = self.to_extra_token(extra)
|
||||
maybe_advantage_embed = empty_token
|
||||
|
||||
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
|
||||
if self.has_advantages and exists(advantages):
|
||||
if isinstance(advantages, int):
|
||||
advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)
|
||||
|
||||
maybe_advantage_embed = self.advantage_emb(advantages + 1)
|
||||
|
||||
# register tokens
|
||||
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
register_tokens = empty_token
|
||||
|
||||
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
|
||||
if exists(self.register_tokens):
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
|
||||
# cross attention
|
||||
# extra
|
||||
|
||||
hiddens = [action_tokens]
|
||||
maybe_extra_embed = empty_token
|
||||
|
||||
has_extra = exists(extra)
|
||||
if has_extra:
|
||||
assert self.accept_extra_token
|
||||
|
||||
maybe_extra_embed = self.to_extra_token(extra)
|
||||
|
||||
# pack all tokens for attention
|
||||
|
||||
tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')
|
||||
|
||||
# transformer
|
||||
|
||||
hiddens = [tokens]
|
||||
|
||||
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
|
||||
|
||||
if exists(tasks):
|
||||
action_tokens = maybe_film(action_tokens, task_emb)
|
||||
if exists(maybe_film) and exists(tasks):
|
||||
tokens = maybe_film(tokens, task_emb)
|
||||
|
||||
action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
|
||||
tokens = cross_attn(tokens, layer_context) + tokens
|
||||
|
||||
if exists(maybe_self_attn):
|
||||
action_tokens = maybe_self_attn(action_tokens) + action_tokens
|
||||
tokens = maybe_self_attn(tokens) + tokens
|
||||
|
||||
action_tokens = ff(action_tokens) + action_tokens
|
||||
tokens = ff(tokens) + tokens
|
||||
|
||||
hiddens.append(action_tokens)
|
||||
hiddens.append(tokens)
|
||||
|
||||
# unpack registers
|
||||
# unpack register, advantage, action, and extra tokens
|
||||
|
||||
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
|
||||
|
||||
# maybe unpack extra
|
||||
|
||||
if has_extra:
|
||||
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
|
||||
maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')
|
||||
|
||||
# norm and prediction
|
||||
|
||||
@@ -501,36 +524,44 @@ if __name__ == '__main__':
|
||||
mlp_dim = 1024
|
||||
)
|
||||
|
||||
vat = VAT(
|
||||
vit,
|
||||
dim = 512,
|
||||
depth = 9,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
mlp_dim = 2048,
|
||||
dim_action = 20,
|
||||
action_chunk_len = 7,
|
||||
time_seq_len = 4,
|
||||
num_views = 2,
|
||||
num_tasks = 4,
|
||||
add_self_attn = True,
|
||||
dim_extra_token = 33, # extra token with some variable dimension
|
||||
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
|
||||
0, 0, 1, 1, 2, 2, 3, 3, 4
|
||||
for num_adv_bins in (0, 2, 10):
|
||||
vat = VAT(
|
||||
vit,
|
||||
dim = 512,
|
||||
depth = 9,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
mlp_dim = 2048,
|
||||
dim_action = 20,
|
||||
action_chunk_len = 7,
|
||||
time_seq_len = 4,
|
||||
num_views = 2,
|
||||
num_tasks = 4,
|
||||
num_advantage_bins = num_adv_bins,
|
||||
add_self_attn = True,
|
||||
dim_extra_token = 33, # extra token with some variable dimension
|
||||
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
|
||||
0, 0, 1, 1, 2, 2, 3, 3, 4
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
|
||||
tasks = torch.randint(0, 4, (2,))
|
||||
extra = torch.randn(2, 33) # extra internal state
|
||||
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
|
||||
tasks = torch.randint(0, 4, (2,))
|
||||
extra = torch.randn(2, 33) # extra internal state
|
||||
|
||||
actions = torch.randn(2, 7, 20) # actions for learning
|
||||
# advantage conditioning
|
||||
|
||||
loss = vat(images, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss.backward()
|
||||
advantages = None
|
||||
if num_adv_bins > 0:
|
||||
advantages = torch.randint(-1, num_adv_bins, (2,))
|
||||
|
||||
# after much training
|
||||
actions = torch.randn(2, 7, 20) # actions for learning
|
||||
|
||||
pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True)
|
||||
loss = vat(images, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss.backward()
|
||||
|
||||
assert pred_actions.shape == (2, 7, 20)
|
||||
# after much training
|
||||
|
||||
pred_actions, hiddens = vat(images, advantages = advantages, tasks = tasks, extra = extra, return_hiddens = True)
|
||||
|
||||
assert pred_actions.shape == (2, 7, 20)
|
||||
|
||||
@@ -188,6 +188,7 @@ class SigLIPVAT(Module):
|
||||
self_attn_heads = 4,
|
||||
self_attn_dim_head = 32,
|
||||
vit_layer_indices: tuple[int, ...] | None = None,
|
||||
num_advantage_bins = 0,
|
||||
siglip_image_size = 224,
|
||||
siglip_patch_size = 14,
|
||||
siglip_dim = 1152,
|
||||
@@ -240,6 +241,14 @@ class SigLIPVAT(Module):
|
||||
|
||||
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
|
||||
|
||||
# handle maybe advantage conditioning
|
||||
|
||||
self.has_advantages = num_advantage_bins > 0
|
||||
self.num_advantage_bins = num_advantage_bins
|
||||
|
||||
if self.has_advantages:
|
||||
self.advantage_emb = nn.Embedding(num_advantage_bins + 1, dim)
|
||||
|
||||
self.layers = ModuleList([])
|
||||
for _ in range(depth):
|
||||
maybe_film = FiLM(dim = dim) if self.has_tasks else None
|
||||
@@ -281,13 +290,13 @@ class SigLIPVAT(Module):
|
||||
# Auto-detect prefix based on keys
|
||||
with safe_open(weights_path, framework = 'pt') as f:
|
||||
keys = f.keys()
|
||||
|
||||
|
||||
vi_p = ''
|
||||
if any(k.startswith('paligemma_with_expert.paligemma.model.vision_tower.vision_model') for k in keys):
|
||||
vi_p = 'paligemma_with_expert.paligemma.model.vision_tower.vision_model.'
|
||||
elif any(k.startswith('vision_model') for k in keys):
|
||||
vi_p = 'vision_model.'
|
||||
|
||||
|
||||
pz_state = self.vit.state_dict()
|
||||
|
||||
def copy_weight_bias(pz_prefix, vi_prefix):
|
||||
@@ -333,15 +342,16 @@ class SigLIPVAT(Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_or_image, # (b v? c t? h w)
|
||||
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
|
||||
*,
|
||||
extra = None,
|
||||
tasks = None,
|
||||
actions = None,
|
||||
extra = None, # (b d) - batch, dim extra
|
||||
tasks = None, # (b)
|
||||
advantages = None,# (b)
|
||||
actions = None, # (b k d) - batch, action chunk length, action dimension
|
||||
return_hiddens = False,
|
||||
freeze_vit = False
|
||||
):
|
||||
batch = video_or_image.shape[0]
|
||||
batch, device = video_or_image.shape[0], video_or_image.device
|
||||
return_loss = exists(actions)
|
||||
|
||||
# handle some various input dimensions
|
||||
@@ -397,46 +407,62 @@ class SigLIPVAT(Module):
|
||||
|
||||
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
|
||||
|
||||
# get main action tokens and maybe append extra
|
||||
# main action tokens
|
||||
|
||||
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
|
||||
action_tokens = repeat(self.action_pos_emb, 'n d -> b n d', b = batch)
|
||||
|
||||
has_extra = exists(extra)
|
||||
if has_extra:
|
||||
extra_token = self.to_extra_token(extra)
|
||||
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
|
||||
# maybe advantage tokens
|
||||
|
||||
empty_token = action_tokens[:, 0:0]
|
||||
|
||||
maybe_advantage_embed = empty_token
|
||||
|
||||
if self.has_advantages and exists(advantages):
|
||||
if isinstance(advantages, int):
|
||||
advantages = torch.full((batch,), advantages, device = device, dtype = torch.long)
|
||||
|
||||
maybe_advantage_embed = self.advantage_emb(advantages + 1)
|
||||
|
||||
# register tokens
|
||||
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
|
||||
register_tokens = empty_token
|
||||
|
||||
# cross attention
|
||||
if exists(self.register_tokens):
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
|
||||
vat_hiddens = [action_tokens]
|
||||
# extra
|
||||
|
||||
maybe_extra_embed = empty_token
|
||||
|
||||
has_extra = exists(extra)
|
||||
if has_extra:
|
||||
maybe_extra_embed = self.to_extra_token(extra)
|
||||
|
||||
# pack all tokens for attention
|
||||
|
||||
tokens, ps = pack((register_tokens, maybe_advantage_embed, action_tokens, maybe_extra_embed), 'b * d')
|
||||
|
||||
# transformer
|
||||
|
||||
vat_hiddens = [tokens]
|
||||
|
||||
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
|
||||
|
||||
if exists(tasks):
|
||||
action_tokens = maybe_film(action_tokens, task_emb)
|
||||
if exists(maybe_film) and exists(tasks):
|
||||
tokens = maybe_film(tokens, task_emb)
|
||||
|
||||
action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
|
||||
tokens = cross_attn(tokens, layer_context) + tokens
|
||||
|
||||
if exists(maybe_self_attn):
|
||||
action_tokens = maybe_self_attn(action_tokens) + action_tokens
|
||||
tokens = maybe_self_attn(tokens) + tokens
|
||||
|
||||
action_tokens = ff(action_tokens) + action_tokens
|
||||
tokens = ff(tokens) + tokens
|
||||
|
||||
vat_hiddens.append(action_tokens)
|
||||
vat_hiddens.append(tokens)
|
||||
|
||||
# unpack registers
|
||||
# unpack register, advantage, action, and extra tokens
|
||||
|
||||
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
|
||||
|
||||
# maybe unpack extra
|
||||
|
||||
if has_extra:
|
||||
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
|
||||
maybe_register_embed, maybe_advantage_embed, action_tokens, maybe_extra_embed = unpack(tokens, ps, 'b * d')
|
||||
|
||||
# norm and prediction
|
||||
|
||||
@@ -456,32 +482,40 @@ class SigLIPVAT(Module):
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
vat = SigLIPVAT(
|
||||
num_tasks = 4,
|
||||
dim_extra_token = 32,
|
||||
time_seq_len = 2,
|
||||
num_views = 2,
|
||||
depth = 4,
|
||||
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
|
||||
0, 1, 26, 27
|
||||
for num_adv_bins in (0, 2, 10):
|
||||
vat = SigLIPVAT(
|
||||
num_tasks = 4,
|
||||
dim_extra_token = 32,
|
||||
time_seq_len = 2,
|
||||
num_views = 2,
|
||||
depth = 4,
|
||||
num_advantage_bins = num_adv_bins,
|
||||
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
|
||||
0, 1, 26, 27
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
vat.load_siglip() # load siglip weights from hf
|
||||
vat.load_siglip() # load siglip weights from hf
|
||||
|
||||
# inputs
|
||||
# inputs
|
||||
|
||||
images = torch.randn(1, 2, 3, 2, 224, 224) # (b, v, c, t, h, w)
|
||||
tasks = torch.randint(0, 4, (1,))
|
||||
extra = torch.randn(1, 32)
|
||||
images = torch.randn(1, 2, 3, 2, 224, 224) # (b, v, c, t, h, w)
|
||||
tasks = torch.randint(0, 4, (1,))
|
||||
extra = torch.randn(1, 32)
|
||||
|
||||
actions = torch.randn(1, 50, 32) # actions for learning
|
||||
# advantage conditioning
|
||||
|
||||
loss = vat(images, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss.backward()
|
||||
advantages = None
|
||||
if num_adv_bins > 0:
|
||||
advantages = torch.randint(-1, num_adv_bins, (1,))
|
||||
|
||||
# after much training
|
||||
actions = torch.randn(1, 50, 32) # actions for learning
|
||||
|
||||
pred_actions = vat(images, tasks = tasks, extra = extra)
|
||||
|
||||
assert pred_actions.shape == (1, 50, 32)
|
||||
loss = vat(images, actions = actions, advantages = advantages, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss.backward()
|
||||
|
||||
# after much training
|
||||
|
||||
pred_actions = vat(images, advantages = advantages, tasks = tasks, extra = extra)
|
||||
|
||||
assert pred_actions.shape == (1, 50, 32)
|
||||
|
||||
254
vit_pytorch/vit_detpool.py
Normal file
254
vit_pytorch/vit_detpool.py
Normal file
@@ -0,0 +1,254 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# DetPool ViT - a vit that accepts an object mask and attends and pools only using that mask - table 1
|
||||
# Dantong Niu et al. - https://openreview.net/forum?id=NZDaMcpXZm
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat, pack, unpack
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def masked_mean(t, mask, dim = 1, eps = 1e-5):
|
||||
if not exists(mask):
|
||||
return t.mean(dim = dim)
|
||||
|
||||
mask = rearrange(mask.bool(), '... -> ... 1')
|
||||
t = t.masked_fill(~mask, 0.)
|
||||
return t.sum(dim = dim) / mask.sum(dim = dim).clamp(min = eps)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x, mask = None):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
mask_value = -torch.finfo(dots.dtype).max
|
||||
dots = dots.masked_fill(~mask, mask_value)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x, mask = None):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, mask = mask) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViTDetPool(Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, use_cls_token = True, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., mask_generator: Module | None = None):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.patch_height = patch_height
|
||||
self.patch_width = patch_width
|
||||
|
||||
self.downsample_mask = Reduce('b (h p1) (w p2) -> b (h w)', 'max', p1 = patch_height, p2 = patch_width)
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
# maybe cls
|
||||
|
||||
self.use_cls_token = use_cls_token
|
||||
|
||||
if use_cls_token:
|
||||
self.cls_token = nn.Parameter(torch.randn(dim) * 1e-2)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim) * 1e-2)
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
|
||||
|
||||
self.mask_generator = mask_generator
|
||||
|
||||
def forward(self, img, object_mask = None):
|
||||
|
||||
if not exists(object_mask) and exists(self.mask_generator):
|
||||
with torch.no_grad():
|
||||
self.mask_generator.eval()
|
||||
object_mask = self.mask_generator(img)
|
||||
|
||||
has_cls = self.use_cls_token
|
||||
|
||||
batch, _, height, width = img.shape
|
||||
tokens = self.to_patch_embedding(img)
|
||||
|
||||
seq = tokens.shape[1]
|
||||
tokens = tokens + self.pos_embedding[:seq]
|
||||
|
||||
if has_cls:
|
||||
cls_token = repeat(self.cls_token, 'd -> b d', b = batch)
|
||||
tokens, packed_shape = pack((cls_token, tokens), 'b * d')
|
||||
|
||||
tokens = self.dropout(tokens)
|
||||
|
||||
# handle the attention mask, and for final pooling
|
||||
|
||||
mask = None
|
||||
|
||||
if exists(object_mask):
|
||||
assert object_mask.ndim in {3, 2}
|
||||
|
||||
if object_mask.shape == (batch, height, width):
|
||||
mask = self.downsample_mask(object_mask)
|
||||
else:
|
||||
mask = object_mask
|
||||
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
|
||||
assert mask.shape == (batch, seq)
|
||||
|
||||
if has_cls:
|
||||
mask = F.pad(mask, (1, 0), value = True)
|
||||
|
||||
# attend with maybe mask
|
||||
|
||||
tokens = self.transformer(tokens, mask = mask)
|
||||
|
||||
if not exists(self.mlp_head):
|
||||
return tokens
|
||||
|
||||
# splice out cls
|
||||
|
||||
if has_cls:
|
||||
_, tokens = unpack(tokens, packed_shape, 'b * d')
|
||||
|
||||
if exists(mask):
|
||||
mask = mask[..., 1:]
|
||||
|
||||
# pooling with the mask
|
||||
|
||||
pooled = masked_mean(tokens, mask, dim = 1)
|
||||
|
||||
pooled = self.to_latent(pooled)
|
||||
return self.mlp_head(pooled)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
vit = ViTDetPool(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
object_mask = torch.randint(0, 2, (1, 256, 256)).bool()
|
||||
|
||||
preds = vit(img, object_mask = object_mask)
|
||||
assert preds.shape == (1, 1000)
|
||||
|
||||
preds_no_mask = vit(img)
|
||||
assert preds_no_mask.shape == (1, 1000)
|
||||
|
||||
# test with module included
|
||||
|
||||
class MockMasker(Module):
|
||||
def forward(self, img):
|
||||
batch, _, height, width = img.shape
|
||||
return torch.ones(batch, height, width).bool()
|
||||
|
||||
vit = ViTDetPool(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 1,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
mask_generator = MockMasker()
|
||||
)
|
||||
|
||||
preds = vit(img)
|
||||
assert preds.shape == (1, 1000)
|
||||
@@ -31,7 +31,7 @@ class FeedForward(Module):
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
@@ -40,31 +40,31 @@ class Attention(Module):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
@@ -79,7 +79,7 @@ class Transformer(Module):
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
@@ -105,73 +105,73 @@ class ViTND(Module):
|
||||
emb_dropout: float = 0.
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
|
||||
self.ndim = ndim
|
||||
self.pool = pool
|
||||
|
||||
|
||||
input_shape = ensure_tuple(input_shape, ndim)
|
||||
patch_size = ensure_tuple(patch_size, ndim)
|
||||
|
||||
|
||||
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
|
||||
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
|
||||
|
||||
|
||||
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
|
||||
num_patches = 1
|
||||
for n in num_patches_per_dim:
|
||||
num_patches *= n
|
||||
|
||||
|
||||
patch_dim = channels
|
||||
for p in patch_size:
|
||||
patch_dim *= p
|
||||
|
||||
|
||||
dim_names = 'fghijkl'[:ndim]
|
||||
|
||||
|
||||
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
|
||||
patch_dims = [f'p{i}' for i in range(ndim)]
|
||||
|
||||
|
||||
input_pattern = f'b c {join(input_dims)}'
|
||||
output_pattern = f'b ({join(dim_names)}) ({join(patch_dims)} c)'
|
||||
rearrange_str = f'{input_pattern} -> {output_pattern}'
|
||||
|
||||
|
||||
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
|
||||
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange(rearrange_str, **rearrange_kwargs),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.to_patch_embedding(x)
|
||||
b, n, _ = x.shape
|
||||
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
|
||||
x = x[:, 1:].mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
||||
model = ViTND(
|
||||
ndim = 4,
|
||||
input_shape = (8, 16, 32, 64),
|
||||
@@ -185,7 +185,7 @@ if __name__ == '__main__':
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
|
||||
occupancy_time = torch.randn(2, 3, 8, 16, 32, 64)
|
||||
|
||||
|
||||
logits = model(occupancy_time)
|
||||
|
||||
@@ -121,7 +121,7 @@ class FeedForward(Module):
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
@@ -130,14 +130,14 @@ class Attention(Module):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
|
||||
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias = False)
|
||||
|
||||
@@ -145,7 +145,7 @@ class Attention(Module):
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
|
||||
def forward(self, x, polar_pos_emb = None):
|
||||
x = self.norm(x)
|
||||
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
|
||||
@@ -156,12 +156,12 @@ class Attention(Module):
|
||||
freqs, bias = polar_pos_emb
|
||||
q = apply_polar_pos_emb(q, freqs)
|
||||
k = apply_polar_pos_emb(k, freqs + bias)
|
||||
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
@@ -180,7 +180,7 @@ class Transformer(Module):
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
|
||||
def forward(self, x, pos = None):
|
||||
|
||||
# pope embedding
|
||||
@@ -219,45 +219,45 @@ class ViTND(Module):
|
||||
pope_init_learned_bias_uniform = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
|
||||
|
||||
|
||||
self.ndim = ndim
|
||||
|
||||
|
||||
input_shape = ensure_tuple(input_shape, ndim)
|
||||
patch_size = ensure_tuple(patch_size, ndim)
|
||||
|
||||
|
||||
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
|
||||
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
|
||||
|
||||
|
||||
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
|
||||
num_patches = 1
|
||||
for n in num_patches_per_dim:
|
||||
num_patches *= n
|
||||
|
||||
|
||||
patch_dim = channels
|
||||
for p in patch_size:
|
||||
patch_dim *= p
|
||||
|
||||
|
||||
dim_names = 'fghijkl'[:ndim]
|
||||
|
||||
|
||||
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
|
||||
patch_dims = [f'p{i}' for i in range(ndim)]
|
||||
|
||||
|
||||
input_pattern = f'b c {join(input_dims)}'
|
||||
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
|
||||
rearrange_str = f'{input_pattern} -> {output_pattern}'
|
||||
|
||||
|
||||
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
|
||||
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange(rearrange_str, **rearrange_kwargs),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
|
||||
# golden gate pope
|
||||
|
||||
self.polar_emb = GoldenGatePoPENd(
|
||||
@@ -269,12 +269,12 @@ class ViTND(Module):
|
||||
p_zero_freqs = pope_p_zero_freqs,
|
||||
init_learned_bias_uniform = pope_init_learned_bias_uniform
|
||||
)
|
||||
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, polar_emb = self.polar_emb)
|
||||
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
|
||||
def muon_parameters(self):
|
||||
params = []
|
||||
|
||||
@@ -298,9 +298,9 @@ class ViTND(Module):
|
||||
return_embed = False
|
||||
):
|
||||
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
|
||||
|
||||
|
||||
batch, *spatial_dims, _, device = *x.shape, x.device
|
||||
|
||||
|
||||
# Generate position coordinates
|
||||
|
||||
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
|
||||
@@ -308,12 +308,12 @@ class ViTND(Module):
|
||||
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
|
||||
|
||||
# flatten spatial dimensions for attention with nd rotary
|
||||
|
||||
|
||||
pos = repeat(pos, '... p -> b (...) p', b = batch)
|
||||
x, packed_shape = pack([x], 'b * d')
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
|
||||
embed = self.transformer(x, pos)
|
||||
|
||||
# return the embed with reconstituted patch shape
|
||||
@@ -330,7 +330,7 @@ class ViTND(Module):
|
||||
return self.mlp_head(pooled)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
||||
model = ViTND(
|
||||
ndim = 5,
|
||||
input_shape = (4, 8, 16, 32, 64),
|
||||
|
||||
@@ -75,23 +75,23 @@ class GoldenGateRoPENd(Module):
|
||||
# input shape: (b, h, n, d) where d = head_dim
|
||||
# pos shape: (b, n, p) where p = pos_dim
|
||||
# self.freqs shape: (h, f, p) where f = d // 2
|
||||
|
||||
|
||||
x, y = input.float().chunk(2, dim = -1) # both (b, h, n, f)
|
||||
|
||||
|
||||
# Expand dimensions for broadcasting
|
||||
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
|
||||
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
|
||||
|
||||
|
||||
# Compute theta for each (batch, head, seq, freq)
|
||||
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
|
||||
|
||||
|
||||
cos_theta = torch.cos(theta)
|
||||
sin_theta = torch.sin(theta)
|
||||
|
||||
|
||||
# Apply rotation
|
||||
x_out = x * cos_theta - y * sin_theta
|
||||
y_out = x * sin_theta + y * cos_theta
|
||||
|
||||
|
||||
output = cat((x_out, y_out), dim=-1)
|
||||
return output.type_as(input)
|
||||
|
||||
@@ -108,7 +108,7 @@ class FeedForward(Module):
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
@@ -117,15 +117,15 @@ class Attention(Module):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.rotary_emb = rotary_emb
|
||||
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
|
||||
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias = False)
|
||||
|
||||
@@ -133,24 +133,24 @@ class Attention(Module):
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
|
||||
def forward(self, x, pos = None):
|
||||
x = self.norm(x)
|
||||
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
|
||||
# Apply rotary embeddings if available
|
||||
if exists(self.rotary_emb):
|
||||
assert exists(pos)
|
||||
q = self.rotary_emb(q, pos)
|
||||
k = self.rotary_emb(k, pos)
|
||||
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
@@ -165,7 +165,7 @@ class Transformer(Module):
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, rotary_emb = rotary_emb),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
|
||||
def forward(self, x, pos = None):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, pos) + x
|
||||
@@ -193,45 +193,45 @@ class ViTND(Module):
|
||||
rope_p_zero_freqs: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
|
||||
|
||||
|
||||
self.ndim = ndim
|
||||
|
||||
|
||||
input_shape = ensure_tuple(input_shape, ndim)
|
||||
patch_size = ensure_tuple(patch_size, ndim)
|
||||
|
||||
|
||||
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
|
||||
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
|
||||
|
||||
|
||||
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
|
||||
num_patches = 1
|
||||
for n in num_patches_per_dim:
|
||||
num_patches *= n
|
||||
|
||||
|
||||
patch_dim = channels
|
||||
for p in patch_size:
|
||||
patch_dim *= p
|
||||
|
||||
|
||||
dim_names = 'fghijkl'[:ndim]
|
||||
|
||||
|
||||
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
|
||||
patch_dims = [f'p{i}' for i in range(ndim)]
|
||||
|
||||
|
||||
input_pattern = f'b c {join(input_dims)}'
|
||||
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
|
||||
rearrange_str = f'{input_pattern} -> {output_pattern}'
|
||||
|
||||
|
||||
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
|
||||
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange(rearrange_str, **rearrange_kwargs),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
|
||||
# Create rotary embeddings
|
||||
self.rotary_emb = GoldenGateRoPENd(
|
||||
dim_pos = ndim,
|
||||
@@ -241,12 +241,12 @@ class ViTND(Module):
|
||||
rope_max_freq = rope_max_freq,
|
||||
rope_p_zero_freqs = rope_p_zero_freqs
|
||||
)
|
||||
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, rotary_emb = self.rotary_emb)
|
||||
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
|
||||
def muon_parameters(self):
|
||||
params = []
|
||||
|
||||
@@ -270,9 +270,9 @@ class ViTND(Module):
|
||||
return_embed = False
|
||||
):
|
||||
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
|
||||
|
||||
|
||||
batch, *spatial_dims, _, device = *x.shape, x.device
|
||||
|
||||
|
||||
# Generate position coordinates
|
||||
|
||||
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
|
||||
@@ -280,12 +280,12 @@ class ViTND(Module):
|
||||
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
|
||||
|
||||
# flatten spatial dimensions for attention with nd rotary
|
||||
|
||||
|
||||
pos = repeat(pos, '... p -> b (...) p', b = batch)
|
||||
x, packed_shape = pack([x], 'b * d')
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
|
||||
embed = self.transformer(x, pos)
|
||||
|
||||
# return the embed with reconstituted patch shape
|
||||
@@ -303,7 +303,7 @@ class ViTND(Module):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
||||
model = ViTND(
|
||||
ndim = 5,
|
||||
input_shape = (4, 8, 16, 32, 64),
|
||||
|
||||
@@ -231,4 +231,4 @@ if __name__ == '__main__':
|
||||
|
||||
decorr_loss = DecorrelationLoss(0.0001, soft_validate_num_sampled = True)
|
||||
out = decorr_loss(hiddens)
|
||||
assert out.item() == 0
|
||||
assert out.item() == 0
|
||||
|
||||
217
vit_pytorch/vit_with_keel_post_ln.py
Normal file
217
vit_pytorch/vit_with_keel_post_ln.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# functions
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
dim_head,
|
||||
mlp_dim,
|
||||
dropout = 0.,
|
||||
keel_residual_scale = None
|
||||
):
|
||||
super().__init__()
|
||||
assert depth > 1
|
||||
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.extend([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
])
|
||||
|
||||
num_layers = depth * 2
|
||||
self.keel_residual_scale = default(keel_residual_scale, num_layers)
|
||||
|
||||
self.post_norms = ModuleList([nn.LayerNorm(dim, bias = False) for _ in range(num_layers - 1)])
|
||||
|
||||
def forward(self, x):
|
||||
residual_scale = self.keel_residual_scale
|
||||
|
||||
for layer_ind, layer in enumerate(self.layers):
|
||||
first_layer = layer_ind == 0
|
||||
|
||||
residual = x
|
||||
|
||||
out = layer(x)
|
||||
|
||||
if first_layer:
|
||||
x = out + residual
|
||||
continue
|
||||
|
||||
post_norm = self.post_norms[layer_ind - 1]
|
||||
|
||||
x = post_norm(out + residual * residual_scale)
|
||||
|
||||
return x
|
||||
|
||||
class ViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
pool = 'cls',
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
keel_residual_scale = None
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
num_cls_tokens = 1 if pool == 'cls' else 0
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
|
||||
self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
dim_head,
|
||||
mlp_dim,
|
||||
dropout,
|
||||
keel_residual_scale = keel_residual_scale
|
||||
)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
|
||||
|
||||
def forward(self, img):
|
||||
batch = img.shape[0]
|
||||
x = self.to_patch_embedding(img)
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
|
||||
seq = x.shape[1]
|
||||
|
||||
x = x + self.pos_embedding[:seq]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
if not exists(self.mlp_head):
|
||||
return x
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
if __name__ == '__main__':
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
preds = v(img)
|
||||
|
||||
assert preds.shape == (1, 1000)
|
||||
@@ -89,7 +89,6 @@ class Attention(Module):
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
|
||||
|
||||
attn = self.attend(dots)
|
||||
@@ -109,7 +108,7 @@ class Transformer(Module):
|
||||
self.layers = ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user