mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
42 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a36546df23 | ||
|
|
d830b05f06 | ||
|
|
8208c859a5 | ||
|
|
4264efd906 | ||
|
|
b194359301 | ||
|
|
950c901b80 | ||
|
|
3e5d1be6f0 | ||
|
|
6e2393de95 | ||
|
|
32974c33df | ||
|
|
17675e0de4 | ||
|
|
598cffab53 | ||
|
|
23820bc54a | ||
|
|
e9ca1f4d57 | ||
|
|
d4daf7bd0f | ||
|
|
9e3fec2398 | ||
|
|
ce4bcd08fb | ||
|
|
ad4ca19775 | ||
|
|
e1b08c15b9 | ||
|
|
c59843d7b8 | ||
|
|
9a8e509b27 | ||
|
|
258dd8c7c6 | ||
|
|
4218556acd | ||
|
|
f621c2b041 | ||
|
|
5699ed7d13 | ||
|
|
46dcaf23d8 | ||
|
|
bdaf2d1491 | ||
|
|
500e23105a | ||
|
|
89e1996c8b | ||
|
|
2f87c0cf8f | ||
|
|
59c8948c6a | ||
|
|
cb6d749821 | ||
|
|
6ec8fdaa6d | ||
|
|
13fabf901e | ||
|
|
c0eb4c0150 | ||
|
|
5f1a6a05e9 | ||
|
|
9a95e7904e | ||
|
|
b4853d39c2 | ||
|
|
29fbf0aff4 | ||
|
|
4b8f5bc900 | ||
|
|
f86e052c05 | ||
|
|
2fa2b62def | ||
|
|
9f87d1c43b |
1
.github/workflows/python-test.yml
vendored
1
.github/workflows/python-test.yml
vendored
@@ -27,6 +27,7 @@ jobs:
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pytest
|
||||
python -m pip install wheel
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
|
||||
276
README.md
276
README.md
@@ -7,6 +7,7 @@
|
||||
- [Usage](#usage)
|
||||
- [Parameters](#parameters)
|
||||
- [Simple ViT](#simple-vit)
|
||||
- [NaViT](#navit)
|
||||
- [Distillation](#distillation)
|
||||
- [Deep ViT](#deep-vit)
|
||||
- [CaiT](#cait)
|
||||
@@ -27,9 +28,12 @@
|
||||
- [Masked Autoencoder](#masked-autoencoder)
|
||||
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
|
||||
- [Masked Patch Prediction](#masked-patch-prediction)
|
||||
- [Masked Position Prediction](#masked-position-prediction)
|
||||
- [Adaptive Token Sampling](#adaptive-token-sampling)
|
||||
- [Patch Merger](#patch-merger)
|
||||
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
|
||||
- [3D Vit](#3d-vit)
|
||||
- [ViVit](#vivit)
|
||||
- [Parallel ViT](#parallel-vit)
|
||||
- [Learnable Memory ViT](#learnable-memory-vit)
|
||||
- [Dino](#dino)
|
||||
@@ -52,6 +56,8 @@ The official Jax repository is <a href="https://github.com/google-research/visio
|
||||
|
||||
A tensorflow2 translation also exists <a href="https://github.com/taki0112/vit-tensorflow">here</a>, created by research scientist <a href="https://github.com/taki0112">Junho Kim</a>! 🙏
|
||||
|
||||
<a href="https://github.com/conceptofmind/vit-flax">Flax translation</a> by <a href="https://github.com/conceptofmind">Enrico Shippole</a>!
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
@@ -134,6 +140,63 @@ img = torch.randn(1, 3, 256, 256)
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## NaViT
|
||||
|
||||
<img src="./images/navit.png" width="450px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2307.06304">This paper</a> proposes to leverage the flexibility of attention and masking for variable lengthed sequences to train images of multiple resolution, packed into a single batch. They demonstrate much faster training and improved accuracies, with the only cost being extra complexity in the architecture and dataloading. They use factorized 2d positional encodings, token dropping, as well as query-key normalization.
|
||||
|
||||
You can use it as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.na_vit import NaViT
|
||||
|
||||
v = NaViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1,
|
||||
token_dropout_prob = 0.1 # token dropout of 10% (keep 90% of tokens)
|
||||
)
|
||||
|
||||
# 5 images of different resolutions - List[List[Tensor]]
|
||||
|
||||
# for now, you'll have to correctly place images in same batch element as to not exceed maximum allowed sequence length for self-attention w/ masking
|
||||
|
||||
images = [
|
||||
[torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
|
||||
[torch.randn(3, 128, 256), torch.randn(3, 256, 128)],
|
||||
[torch.randn(3, 64, 256)]
|
||||
]
|
||||
|
||||
preds = v(images) # (5, 1000) - 5, because 5 images of different resolution above
|
||||
|
||||
```
|
||||
|
||||
Or if you would rather that the framework auto group the images into variable lengthed sequences that do not exceed a certain max length
|
||||
|
||||
```python
|
||||
images = [
|
||||
torch.randn(3, 256, 256),
|
||||
torch.randn(3, 128, 128),
|
||||
torch.randn(3, 128, 256),
|
||||
torch.randn(3, 256, 128),
|
||||
torch.randn(3, 64, 256)
|
||||
]
|
||||
|
||||
preds = v(
|
||||
images,
|
||||
group_images = True,
|
||||
group_max_seq_len = 64
|
||||
) # (5, 1000)
|
||||
```
|
||||
|
||||
## Distillation
|
||||
|
||||
<img src="./images/distill.png" width="300px"></img>
|
||||
@@ -299,7 +362,7 @@ cct = CCT(
|
||||
pooling_padding = 1,
|
||||
num_layers = 14,
|
||||
num_heads = 6,
|
||||
mlp_radio = 3.,
|
||||
mlp_ratio = 3.,
|
||||
num_classes = 1000,
|
||||
positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
|
||||
)
|
||||
@@ -661,7 +724,7 @@ preds = v(img) # (2, 1000)
|
||||
|
||||
<img src="./images/nest.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the heirarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.
|
||||
This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the hierarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.
|
||||
|
||||
You can use it with the following code (ex. NesT-T)
|
||||
|
||||
@@ -675,7 +738,7 @@ nest = NesT(
|
||||
dim = 96,
|
||||
heads = 3,
|
||||
num_hierarchies = 3, # number of hierarchies
|
||||
block_repeats = (2, 2, 8), # the number of transformer blocks at each heirarchy, starting from the bottom
|
||||
block_repeats = (2, 2, 8), # the number of transformer blocks at each hierarchy, starting from the bottom
|
||||
num_classes = 1000
|
||||
)
|
||||
|
||||
@@ -840,6 +903,44 @@ for _ in range(100):
|
||||
torch.save(model.state_dict(), './pretrained-net.pt')
|
||||
```
|
||||
|
||||
## Masked Position Prediction
|
||||
|
||||
<img src="./images/mp3.png" width="400px"></img>
|
||||
|
||||
New <a href="https://arxiv.org/abs/2207.07611">paper</a> that introduces masked position prediction pre-training criteria. This strategy is more efficient than the Masked Autoencoder strategy and has comparable performance.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.mp3 import ViT, MP3
|
||||
|
||||
v = ViT(
|
||||
num_classes = 1000,
|
||||
image_size = 256,
|
||||
patch_size = 8,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
)
|
||||
|
||||
mp3 = MP3(
|
||||
vit = v,
|
||||
masking_ratio = 0.75
|
||||
)
|
||||
|
||||
images = torch.randn(8, 3, 256, 256)
|
||||
|
||||
loss = mp3(images)
|
||||
loss.backward()
|
||||
|
||||
# that's all!
|
||||
# do the above in a for loop many times with a lot of images and your vision transformer will learn
|
||||
|
||||
# save your improved vision transformer
|
||||
torch.save(v.state_dict(), './trained-vit.pt')
|
||||
```
|
||||
|
||||
## Adaptive Token Sampling
|
||||
|
||||
<img src="./images/ats.png" width="400px"></img>
|
||||
@@ -965,6 +1066,117 @@ img = torch.randn(4, 3, 256, 256)
|
||||
tokens = spt(img) # (4, 256, 1024)
|
||||
```
|
||||
|
||||
## 3D ViT
|
||||
|
||||
By popular request, I will start extending a few of the architectures in this repository to 3D ViTs, for use with video, medical imaging, etc.
|
||||
|
||||
You will need to pass in two additional hyperparameters: (1) the number of frames `frames` and (2) patch size along the frame dimension `frame_patch_size`
|
||||
|
||||
For starters, 3D ViT
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit_3d import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 128, # image size
|
||||
frames = 16, # number of frames
|
||||
image_patch_size = 16, # image patch size
|
||||
frame_patch_size = 2, # frame patch size
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
|
||||
|
||||
preds = v(video) # (4, 1000)
|
||||
```
|
||||
|
||||
3D Simple ViT
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.simple_vit_3d import SimpleViT
|
||||
|
||||
v = SimpleViT(
|
||||
image_size = 128, # image size
|
||||
frames = 16, # number of frames
|
||||
image_patch_size = 16, # image patch size
|
||||
frame_patch_size = 2, # frame patch size
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
|
||||
|
||||
preds = v(video) # (4, 1000)
|
||||
```
|
||||
|
||||
3D version of <a href="https://github.com/lucidrains/vit-pytorch#cct">CCT</a>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.cct_3d import CCT
|
||||
|
||||
cct = CCT(
|
||||
img_size = 224,
|
||||
num_frames = 8,
|
||||
embedding_dim = 384,
|
||||
n_conv_layers = 2,
|
||||
frame_kernel_size = 3,
|
||||
kernel_size = 7,
|
||||
stride = 2,
|
||||
padding = 3,
|
||||
pooling_kernel_size = 3,
|
||||
pooling_stride = 2,
|
||||
pooling_padding = 1,
|
||||
num_layers = 14,
|
||||
num_heads = 6,
|
||||
mlp_ratio = 3.,
|
||||
num_classes = 1000,
|
||||
positional_embedding = 'learnable'
|
||||
)
|
||||
|
||||
video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, frames, height, width)
|
||||
pred = cct(video)
|
||||
```
|
||||
|
||||
## ViViT
|
||||
|
||||
<img src="./images/vivit.png" width="350px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository will offer the first variant, which is a spatial transformer followed by a temporal one.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vivit import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 128, # image size
|
||||
frames = 16, # number of frames
|
||||
image_patch_size = 16, # image patch size
|
||||
frame_patch_size = 2, # frame patch size
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
spatial_depth = 6, # depth of the spatial transformer
|
||||
temporal_depth = 6, # depth of the temporal transformer
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
|
||||
|
||||
preds = v(video) # (4, 1000)
|
||||
```
|
||||
|
||||
## Parallel ViT
|
||||
|
||||
<img src="./images/parallel-vit.png" width="350px"></img>
|
||||
@@ -1748,6 +1960,46 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Arnab2021ViViTAV,
|
||||
title = {ViViT: A Video Vision Transformer},
|
||||
author = {Anurag Arnab and Mostafa Dehghani and Georg Heigold and Chen Sun and Mario Lucic and Cordelia Schmid},
|
||||
journal = {2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
|
||||
year = {2021},
|
||||
pages = {6816-6826}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Liu2022PatchDropoutEV,
|
||||
title = {PatchDropout: Economizing Vision Transformers Using Patch Dropout},
|
||||
author = {Yue Liu and Christos Matsoukas and Fredrik Strand and Hossein Azizpour and Kevin Smith},
|
||||
journal = {ArXiv},
|
||||
year = {2022},
|
||||
volume = {abs/2208.07220}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{https://doi.org/10.48550/arxiv.2302.01327,
|
||||
doi = {10.48550/ARXIV.2302.01327},
|
||||
url = {https://arxiv.org/abs/2302.01327},
|
||||
author = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
|
||||
title = {Dual PatchNorm},
|
||||
publisher = {arXiv},
|
||||
year = {2023},
|
||||
copyright = {Creative Commons Attribution 4.0 International}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Dehghani2023PatchNP,
|
||||
title = {Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution},
|
||||
author = {Mostafa Dehghani and Basil Mustafa and Josip Djolonga and Jonathan Heek and Matthias Minderer and Mathilde Caron and Andreas Steiner and Joan Puigcerver and Robert Geirhos and Ibrahim M. Alabdulmohsin and Avital Oliver and Piotr Padlewski and Alexey A. Gritsenko and Mario Luvci'c and Neil Houlsby},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
@@ -1759,4 +2011,22 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{dao2022flashattention,
|
||||
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
|
||||
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
|
||||
booktitle = {Advances in Neural Information Processing Systems},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Darcet2023VisionTN,
|
||||
title = {Vision Transformers Need Registers},
|
||||
author = {Timoth'ee Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
|
||||
year = {2023},
|
||||
url = {https://api.semanticscholar.org/CorpusID:263134283}
|
||||
}
|
||||
```
|
||||
|
||||
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
"\n",
|
||||
"* Dogs vs. Cats Redux: Kernels Edition - https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition\n",
|
||||
"* Base Code - https://www.kaggle.com/reukki/pytorch-cnn-tutorial-with-cats-and-dogs/\n",
|
||||
"* Effecient Attention Implementation - https://github.com/lucidrains/vit-pytorch#efficient-attention"
|
||||
"* Efficient Attention Implementation - https://github.com/lucidrains/vit-pytorch#efficient-attention"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -342,7 +342,7 @@
|
||||
"id": "ZhYDJXk2SRDu"
|
||||
},
|
||||
"source": [
|
||||
"## Image Augumentation"
|
||||
"## Image Augmentation"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -497,7 +497,7 @@
|
||||
"id": "TF9yMaRrSvmv"
|
||||
},
|
||||
"source": [
|
||||
"## Effecient Attention"
|
||||
"## Efficient Attention"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1307,7 +1307,7 @@
|
||||
"celltoolbar": "Edit Metadata",
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"name": "Effecient Attention | Cats & Dogs",
|
||||
"name": "Efficient Attention | Cats & Dogs",
|
||||
"provenance": [],
|
||||
"toc_visible": true
|
||||
},
|
||||
|
||||
BIN
images/mp3.png
Normal file
BIN
images/mp3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 518 KiB |
BIN
images/navit.png
Normal file
BIN
images/navit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 133 KiB |
BIN
images/vivit.png
Normal file
BIN
images/vivit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 104 KiB |
8
setup.py
8
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.35.5',
|
||||
version = '1.5.0 ',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description_content_type = 'text/markdown',
|
||||
@@ -16,7 +16,7 @@ setup(
|
||||
'image recognition'
|
||||
],
|
||||
install_requires=[
|
||||
'einops>=0.4.1',
|
||||
'einops>=0.6.1',
|
||||
'torch>=1.10',
|
||||
'torchvision'
|
||||
],
|
||||
@@ -24,7 +24,9 @@ setup(
|
||||
'pytest-runner',
|
||||
],
|
||||
tests_require=[
|
||||
'pytest'
|
||||
'pytest',
|
||||
'torch==1.12.1',
|
||||
'torchvision==0.13.1'
|
||||
],
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('2.0.0'):
|
||||
from einops._torch_specific import allow_ops_in_compiled_graph
|
||||
allow_ops_in_compiled_graph()
|
||||
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.simple_vit import SimpleViT
|
||||
|
||||
|
||||
@@ -110,18 +110,11 @@ class AdaptiveTokenSampling(nn.Module):
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -138,6 +131,7 @@ class Attention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@@ -154,6 +148,7 @@ class Attention(nn.Module):
|
||||
def forward(self, x, *, mask):
|
||||
num_tokens = x.shape[1]
|
||||
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
@@ -189,8 +184,8 @@ class Transformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
@@ -230,7 +225,9 @@ class ViT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
|
||||
@@ -44,18 +44,11 @@ class LayerScale(nn.Module):
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(x, **kwargs) * self.scale
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -72,6 +65,7 @@ class Attention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
@@ -89,6 +83,7 @@ class Attention(nn.Module):
|
||||
def forward(self, x, context = None):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
|
||||
x = self.norm(x)
|
||||
context = x if not exists(context) else torch.cat((x, context), dim = 1)
|
||||
|
||||
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
|
||||
@@ -115,8 +110,8 @@ class Transformer(nn.Module):
|
||||
|
||||
for ind in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
|
||||
LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
|
||||
LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = ind + 1),
|
||||
LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = ind + 1)
|
||||
]))
|
||||
def forward(self, x, context = None):
|
||||
layers = dropout_layers(self.layers, dropout = self.layer_dropout)
|
||||
@@ -150,7 +145,9 @@ class CaiT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
@@ -50,8 +58,9 @@ def cct_16(*args, **kwargs):
|
||||
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
|
||||
kernel_size=3, stride=None, padding=None,
|
||||
*args, **kwargs):
|
||||
stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
|
||||
padding = padding if padding is not None else max(1, (kernel_size // 2))
|
||||
stride = default(stride, max(1, (kernel_size // 2) - 1))
|
||||
padding = default(padding, max(1, (kernel_size // 2)))
|
||||
|
||||
return CCT(num_layers=num_layers,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
@@ -61,13 +70,22 @@ def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
|
||||
padding=padding,
|
||||
*args, **kwargs)
|
||||
|
||||
# positional
|
||||
|
||||
def sinusoidal_embedding(n_channels, dim):
|
||||
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
|
||||
for p in range(n_channels)])
|
||||
pe[:, 0::2] = torch.sin(pe[:, 0::2])
|
||||
pe[:, 1::2] = torch.cos(pe[:, 1::2])
|
||||
return rearrange(pe, '... -> 1 ...')
|
||||
|
||||
# modules
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // self.num_heads
|
||||
self.heads = num_heads
|
||||
head_dim = dim // self.heads
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
@@ -77,17 +95,20 @@ class Attention(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
qkv = self.qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
attn = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
x = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
x = rearrange(x, 'b h n d -> b n (h d)')
|
||||
|
||||
return self.proj_drop(self.proj(x))
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
@@ -97,7 +118,8 @@ class TransformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
||||
attention_dropout=0.1, drop_path_rate=0.1):
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.pre_norm = nn.LayerNorm(d_model)
|
||||
self.self_attn = Attention(dim=d_model, num_heads=nhead,
|
||||
attention_dropout=attention_dropout, projection_dropout=dropout)
|
||||
@@ -108,50 +130,34 @@ class TransformerEncoderLayer(nn.Module):
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
||||
self.drop_path = DropPath(drop_path_rate)
|
||||
|
||||
self.activation = F.gelu
|
||||
|
||||
def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
def forward(self, src, *args, **kwargs):
|
||||
src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
|
||||
src = src + self.drop_path(self.dropout2(src2))
|
||||
return src
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||
"""
|
||||
Obtained from: github.com:rwightman/pytorch-image-models
|
||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
"""
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = x.div(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""
|
||||
Obtained from: github.com:rwightman/pytorch-image-models
|
||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
super().__init__()
|
||||
self.drop_prob = float(drop_prob)
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype
|
||||
|
||||
if drop_prob <= 0. or not self.training:
|
||||
return x
|
||||
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (batch, *((1,) * (x.ndim - 1)))
|
||||
|
||||
keep_mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < keep_prob
|
||||
output = x.div(keep_prob) * keep_mask.float()
|
||||
return output
|
||||
|
||||
class Tokenizer(nn.Module):
|
||||
def __init__(self,
|
||||
@@ -164,34 +170,35 @@ class Tokenizer(nn.Module):
|
||||
activation=None,
|
||||
max_pool=True,
|
||||
conv_bias=False):
|
||||
super(Tokenizer, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
n_filter_list = [n_input_channels] + \
|
||||
[in_planes for _ in range(n_conv_layers - 1)] + \
|
||||
[n_output_channels]
|
||||
|
||||
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
|
||||
|
||||
self.conv_layers = nn.Sequential(
|
||||
*[nn.Sequential(
|
||||
nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
|
||||
nn.Conv2d(chan_in, chan_out,
|
||||
kernel_size=(kernel_size, kernel_size),
|
||||
stride=(stride, stride),
|
||||
padding=(padding, padding), bias=conv_bias),
|
||||
nn.Identity() if activation is None else activation(),
|
||||
nn.Identity() if not exists(activation) else activation(),
|
||||
nn.MaxPool2d(kernel_size=pooling_kernel_size,
|
||||
stride=pooling_stride,
|
||||
padding=pooling_padding) if max_pool else nn.Identity()
|
||||
)
|
||||
for i in range(n_conv_layers)
|
||||
for chan_in, chan_out in n_filter_list_pairs
|
||||
])
|
||||
|
||||
self.flattener = nn.Flatten(2, 3)
|
||||
self.apply(self.init_weight)
|
||||
|
||||
def sequence_length(self, n_channels=3, height=224, width=224):
|
||||
return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
|
||||
|
||||
def forward(self, x):
|
||||
return self.flattener(self.conv_layers(x)).transpose(-2, -1)
|
||||
return rearrange(self.conv_layers(x), 'b c h w -> b (h w) c')
|
||||
|
||||
@staticmethod
|
||||
def init_weight(m):
|
||||
@@ -214,106 +221,104 @@ class TransformerClassifier(nn.Module):
|
||||
sequence_length=None,
|
||||
*args, **kwargs):
|
||||
super().__init__()
|
||||
positional_embedding = positional_embedding if \
|
||||
positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
|
||||
assert positional_embedding in {'sine', 'learnable', 'none'}
|
||||
|
||||
dim_feedforward = int(embedding_dim * mlp_ratio)
|
||||
self.embedding_dim = embedding_dim
|
||||
self.sequence_length = sequence_length
|
||||
self.seq_pool = seq_pool
|
||||
|
||||
assert sequence_length is not None or positional_embedding == 'none', \
|
||||
assert exists(sequence_length) or positional_embedding == 'none', \
|
||||
f"Positional embedding is set to {positional_embedding} and" \
|
||||
f" the sequence length was not specified."
|
||||
|
||||
if not seq_pool:
|
||||
sequence_length += 1
|
||||
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim),
|
||||
requires_grad=True)
|
||||
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad=True)
|
||||
else:
|
||||
self.attention_pool = nn.Linear(self.embedding_dim, 1)
|
||||
|
||||
if positional_embedding != 'none':
|
||||
if positional_embedding == 'learnable':
|
||||
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
|
||||
requires_grad=True)
|
||||
nn.init.trunc_normal_(self.positional_emb, std=0.2)
|
||||
else:
|
||||
self.positional_emb = nn.Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
|
||||
requires_grad=False)
|
||||
else:
|
||||
if positional_embedding == 'none':
|
||||
self.positional_emb = None
|
||||
elif positional_embedding == 'learnable':
|
||||
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
|
||||
requires_grad=True)
|
||||
nn.init.trunc_normal_(self.positional_emb, std=0.2)
|
||||
else:
|
||||
self.positional_emb = nn.Parameter(sinusoidal_embedding(sequence_length, embedding_dim),
|
||||
requires_grad=False)
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
|
||||
dim_feedforward=dim_feedforward, dropout=dropout_rate,
|
||||
attention_dropout=attention_dropout, drop_path_rate=dpr[i])
|
||||
for i in range(num_layers)])
|
||||
attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
|
||||
for layer_dpr in dpr])
|
||||
|
||||
self.norm = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.fc = nn.Linear(embedding_dim, num_classes)
|
||||
self.apply(self.init_weight)
|
||||
|
||||
def forward(self, x):
|
||||
if self.positional_emb is None and x.size(1) < self.sequence_length:
|
||||
b = x.shape[0]
|
||||
|
||||
if not exists(self.positional_emb) and x.size(1) < self.sequence_length:
|
||||
x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)
|
||||
|
||||
if not self.seq_pool:
|
||||
cls_token = self.class_emb.expand(x.shape[0], -1, -1)
|
||||
cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
|
||||
if self.positional_emb is not None:
|
||||
if exists(self.positional_emb):
|
||||
x += self.positional_emb
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
if self.seq_pool:
|
||||
x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
|
||||
attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
|
||||
x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x)
|
||||
else:
|
||||
x = x[:, 0]
|
||||
|
||||
x = self.fc(x)
|
||||
return x
|
||||
return self.fc(x)
|
||||
|
||||
@staticmethod
|
||||
def init_weight(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
if isinstance(m, nn.Linear) and exists(m.bias):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
@staticmethod
|
||||
def sinusoidal_embedding(n_channels, dim):
|
||||
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
|
||||
for p in range(n_channels)])
|
||||
pe[:, 0::2] = torch.sin(pe[:, 0::2])
|
||||
pe[:, 1::2] = torch.cos(pe[:, 1::2])
|
||||
return pe.unsqueeze(0)
|
||||
|
||||
|
||||
# CCT Main model
|
||||
|
||||
class CCT(nn.Module):
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
embedding_dim=768,
|
||||
n_input_channels=3,
|
||||
n_conv_layers=1,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
*args, **kwargs):
|
||||
super(CCT, self).__init__()
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
embedding_dim=768,
|
||||
n_input_channels=3,
|
||||
n_conv_layers=1,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
*args, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
img_height, img_width = pair(img_size)
|
||||
|
||||
self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
|
||||
|
||||
376
vit_pytorch/cct_3d.py
Normal file
376
vit_pytorch/cct_3d.py
Normal file
@@ -0,0 +1,376 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# CCT Models
|
||||
|
||||
__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']
|
||||
|
||||
|
||||
def cct_2(*args, **kwargs):
|
||||
return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_4(*args, **kwargs):
|
||||
return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_6(*args, **kwargs):
|
||||
return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_7(*args, **kwargs):
|
||||
return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_8(*args, **kwargs):
|
||||
return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_14(*args, **kwargs):
|
||||
return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_16(*args, **kwargs):
|
||||
return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
|
||||
kernel_size=3, stride=None, padding=None,
|
||||
*args, **kwargs):
|
||||
stride = default(stride, max(1, (kernel_size // 2) - 1))
|
||||
padding = default(padding, max(1, (kernel_size // 2)))
|
||||
|
||||
return CCT(num_layers=num_layers,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
embedding_dim=embedding_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
*args, **kwargs)
|
||||
|
||||
# positional
|
||||
|
||||
def sinusoidal_embedding(n_channels, dim):
|
||||
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
|
||||
for p in range(n_channels)])
|
||||
pe[:, 0::2] = torch.sin(pe[:, 0::2])
|
||||
pe[:, 1::2] = torch.cos(pe[:, 1::2])
|
||||
return rearrange(pe, '... -> 1 ...')
|
||||
|
||||
# modules
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
|
||||
super().__init__()
|
||||
self.heads = num_heads
|
||||
head_dim = dim // self.heads
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
self.attn_drop = nn.Dropout(attention_dropout)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(projection_dropout)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
|
||||
qkv = self.qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
attn = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
x = rearrange(x, 'b h n d -> b n (h d)')
|
||||
|
||||
return self.proj_drop(self.proj(x))
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
Inspired by torch.nn.TransformerEncoderLayer and
|
||||
rwightman's timm package.
|
||||
"""
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
||||
attention_dropout=0.1, drop_path_rate=0.1):
|
||||
super().__init__()
|
||||
|
||||
self.pre_norm = nn.LayerNorm(d_model)
|
||||
self.self_attn = Attention(dim=d_model, num_heads=nhead,
|
||||
attention_dropout=attention_dropout, projection_dropout=dropout)
|
||||
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.drop_path = DropPath(drop_path_rate)
|
||||
|
||||
self.activation = F.gelu
|
||||
|
||||
def forward(self, src, *args, **kwargs):
|
||||
src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
|
||||
src = src + self.drop_path(self.dropout2(src2))
|
||||
return src
|
||||
|
||||
class DropPath(nn.Module):
|
||||
def __init__(self, drop_prob=None):
|
||||
super().__init__()
|
||||
self.drop_prob = float(drop_prob)
|
||||
|
||||
def forward(self, x):
|
||||
batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype
|
||||
|
||||
if drop_prob <= 0. or not self.training:
|
||||
return x
|
||||
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (batch, *((1,) * (x.ndim - 1)))
|
||||
|
||||
keep_mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < keep_prob
|
||||
output = x.div(keep_prob) * keep_mask.float()
|
||||
return output
|
||||
|
||||
class Tokenizer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
frame_kernel_size,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
frame_stride=1,
|
||||
frame_pooling_stride=1,
|
||||
frame_pooling_kernel_size=1,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
n_conv_layers=1,
|
||||
n_input_channels=3,
|
||||
n_output_channels=64,
|
||||
in_planes=64,
|
||||
activation=None,
|
||||
max_pool=True,
|
||||
conv_bias=False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
n_filter_list = [n_input_channels] + \
|
||||
[in_planes for _ in range(n_conv_layers - 1)] + \
|
||||
[n_output_channels]
|
||||
|
||||
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
|
||||
|
||||
self.conv_layers = nn.Sequential(
|
||||
*[nn.Sequential(
|
||||
nn.Conv3d(chan_in, chan_out,
|
||||
kernel_size=(frame_kernel_size, kernel_size, kernel_size),
|
||||
stride=(frame_stride, stride, stride),
|
||||
padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias),
|
||||
nn.Identity() if not exists(activation) else activation(),
|
||||
nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
|
||||
stride=(frame_pooling_stride, pooling_stride, pooling_stride),
|
||||
padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
|
||||
)
|
||||
for chan_in, chan_out in n_filter_list_pairs
|
||||
])
|
||||
|
||||
self.apply(self.init_weight)
|
||||
|
||||
def sequence_length(self, n_channels=3, frames=8, height=224, width=224):
|
||||
return self.forward(torch.zeros((1, n_channels, frames, height, width))).shape[1]
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_layers(x)
|
||||
return rearrange(x, 'b c f h w -> b (f h w) c')
|
||||
|
||||
@staticmethod
|
||||
def init_weight(m):
|
||||
if isinstance(m, nn.Conv3d):
|
||||
nn.init.kaiming_normal_(m.weight)
|
||||
|
||||
|
||||
class TransformerClassifier(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
seq_pool=True,
|
||||
embedding_dim=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
num_classes=1000,
|
||||
dropout_rate=0.1,
|
||||
attention_dropout=0.1,
|
||||
stochastic_depth_rate=0.1,
|
||||
positional_embedding='sine',
|
||||
sequence_length=None,
|
||||
*args, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert positional_embedding in {'sine', 'learnable', 'none'}
|
||||
|
||||
dim_feedforward = int(embedding_dim * mlp_ratio)
|
||||
self.embedding_dim = embedding_dim
|
||||
self.sequence_length = sequence_length
|
||||
self.seq_pool = seq_pool
|
||||
|
||||
assert exists(sequence_length) or positional_embedding == 'none', \
|
||||
f"Positional embedding is set to {positional_embedding} and" \
|
||||
f" the sequence length was not specified."
|
||||
|
||||
if not seq_pool:
|
||||
sequence_length += 1
|
||||
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim))
|
||||
else:
|
||||
self.attention_pool = nn.Linear(self.embedding_dim, 1)
|
||||
|
||||
if positional_embedding == 'none':
|
||||
self.positional_emb = None
|
||||
elif positional_embedding == 'learnable':
|
||||
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim))
|
||||
nn.init.trunc_normal_(self.positional_emb, std = 0.2)
|
||||
else:
|
||||
self.register_buffer('positional_emb', sinusoidal_embedding(sequence_length, embedding_dim))
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
|
||||
dim_feedforward=dim_feedforward, dropout=dropout_rate,
|
||||
attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
|
||||
for layer_dpr in dpr])
|
||||
|
||||
self.norm = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.fc = nn.Linear(embedding_dim, num_classes)
|
||||
self.apply(self.init_weight)
|
||||
|
||||
@staticmethod
|
||||
def init_weight(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and exists(m.bias):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def forward(self, x):
|
||||
b = x.shape[0]
|
||||
|
||||
if not exists(self.positional_emb) and x.size(1) < self.sequence_length:
|
||||
x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)
|
||||
|
||||
if not self.seq_pool:
|
||||
cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
|
||||
if exists(self.positional_emb):
|
||||
x += self.positional_emb
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
if self.seq_pool:
|
||||
attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
|
||||
x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x)
|
||||
else:
|
||||
x = x[:, 0]
|
||||
|
||||
return self.fc(x)
|
||||
|
||||
# CCT Main model
|
||||
|
||||
class CCT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
num_frames=8,
|
||||
embedding_dim=768,
|
||||
n_input_channels=3,
|
||||
n_conv_layers=1,
|
||||
frame_stride=1,
|
||||
frame_kernel_size=3,
|
||||
frame_pooling_kernel_size=1,
|
||||
frame_pooling_stride=1,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
*args, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
img_height, img_width = pair(img_size)
|
||||
|
||||
self.tokenizer = Tokenizer(
|
||||
n_input_channels=n_input_channels,
|
||||
n_output_channels=embedding_dim,
|
||||
frame_stride=frame_stride,
|
||||
frame_kernel_size=frame_kernel_size,
|
||||
frame_pooling_stride=frame_pooling_stride,
|
||||
frame_pooling_kernel_size=frame_pooling_kernel_size,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
pooling_stride=pooling_stride,
|
||||
pooling_padding=pooling_padding,
|
||||
max_pool=True,
|
||||
activation=nn.ReLU,
|
||||
n_conv_layers=n_conv_layers,
|
||||
conv_bias=False
|
||||
)
|
||||
|
||||
self.classifier = TransformerClassifier(
|
||||
sequence_length=self.tokenizer.sequence_length(
|
||||
n_channels=n_input_channels,
|
||||
frames=num_frames,
|
||||
height=img_height,
|
||||
width=img_width
|
||||
),
|
||||
embedding_dim=embedding_dim,
|
||||
seq_pool=True,
|
||||
dropout_rate=0.,
|
||||
attention_dropout=0.1,
|
||||
stochastic_depth=0.1,
|
||||
*args, **kwargs
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.tokenizer(x)
|
||||
return self.classifier(x)
|
||||
@@ -13,22 +13,13 @@ def exists(val):
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
# pre-layernorm
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
# feedforward
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -47,6 +38,7 @@ class Attention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@@ -60,6 +52,7 @@ class Attention(nn.Module):
|
||||
|
||||
def forward(self, x, context = None, kv_include_self = False):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
x = self.norm(x)
|
||||
context = default(context, x)
|
||||
|
||||
if kv_include_self:
|
||||
@@ -86,8 +79,8 @@ class Transformer(nn.Module):
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
@@ -121,8 +114,8 @@ class CrossTransformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
ProjectInOut(sm_dim, lg_dim, PreNorm(lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout))),
|
||||
ProjectInOut(lg_dim, sm_dim, PreNorm(sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout)))
|
||||
ProjectInOut(sm_dim, lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
ProjectInOut(lg_dim, sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout))
|
||||
]))
|
||||
|
||||
def forward(self, sm_tokens, lg_tokens):
|
||||
@@ -186,7 +179,9 @@ class ImageEmbedder(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
|
||||
@@ -34,19 +34,11 @@ class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
x = self.norm(x)
|
||||
return self.fn(x, **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
LayerNorm(dim),
|
||||
nn.Conv2d(dim, dim * mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -75,6 +67,7 @@ class Attention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@@ -89,6 +82,8 @@ class Attention(nn.Module):
|
||||
def forward(self, x):
|
||||
shape = x.shape
|
||||
b, n, _, y, h = *shape, self.heads
|
||||
|
||||
x = self.norm(x)
|
||||
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
|
||||
|
||||
@@ -107,8 +102,8 @@ class Transformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
|
||||
Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_mult, dropout = dropout)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
|
||||
@@ -5,25 +5,11 @@ import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(x, **kwargs) + x
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -40,6 +26,7 @@ class Attention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
@@ -59,6 +46,8 @@ class Attention(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
@@ -86,13 +75,13 @@ class Transformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
|
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x)
|
||||
x = ff(x)
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class DeepViT(nn.Module):
|
||||
@@ -105,7 +94,9 @@ class DeepViT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
|
||||
@@ -17,7 +17,9 @@ class ViT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
|
||||
@@ -4,6 +4,12 @@ from torch import nn
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def identity(t):
|
||||
return t
|
||||
|
||||
def clone_and_detach(t):
|
||||
return t.clone().detach()
|
||||
|
||||
def apply_tuple_or_single(fn, val):
|
||||
if isinstance(val, tuple):
|
||||
return tuple(map(fn, val))
|
||||
@@ -17,7 +23,8 @@ class Extractor(nn.Module):
|
||||
layer = None,
|
||||
layer_name = 'transformer',
|
||||
layer_save_input = False,
|
||||
return_embeddings_only = False
|
||||
return_embeddings_only = False,
|
||||
detach = True
|
||||
):
|
||||
super().__init__()
|
||||
self.vit = vit
|
||||
@@ -34,9 +41,11 @@ class Extractor(nn.Module):
|
||||
self.layer_save_input = layer_save_input # whether to save input or output of layer
|
||||
self.return_embeddings_only = return_embeddings_only
|
||||
|
||||
self.detach_fn = clone_and_detach if detach else identity
|
||||
|
||||
def _hook(self, _, inputs, output):
|
||||
layer_output = inputs if self.layer_save_input else output
|
||||
self.latents = apply_tuple_or_single(lambda t: t.clone().detach(), layer_output)
|
||||
self.latents = apply_tuple_or_single(self.detach_fn, layer_output)
|
||||
|
||||
def _register_hook(self):
|
||||
if not exists(self.layer):
|
||||
|
||||
@@ -118,7 +118,9 @@ class ViT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
|
||||
@@ -26,16 +26,6 @@ class ExcludeCLS(nn.Module):
|
||||
x = self.fn(x, **kwargs)
|
||||
return torch.cat((cls_token, x), dim = 1)
|
||||
|
||||
# prenorm
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
# feed forward related classes
|
||||
|
||||
class DepthWiseConv2d(nn.Module):
|
||||
@@ -52,6 +42,7 @@ class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Conv2d(dim, hidden_dim, 1),
|
||||
nn.Hardswish(),
|
||||
DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
|
||||
@@ -77,6 +68,7 @@ class Attention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
@@ -88,6 +80,8 @@ class Attention(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
@@ -106,8 +100,8 @@ class Transformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
|
||||
ExcludeCLS(Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))))
|
||||
Residual(Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
ExcludeCLS(Residual(FeedForward(dim, mlp_dim, dropout = dropout)))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
@@ -126,7 +120,9 @@ class LocalViT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
|
||||
@@ -24,11 +24,14 @@ class MAE(nn.Module):
|
||||
|
||||
self.encoder = encoder
|
||||
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
|
||||
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
|
||||
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
|
||||
|
||||
self.to_patch = encoder.to_patch_embedding[0]
|
||||
self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])
|
||||
|
||||
pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]
|
||||
|
||||
# decoder parameters
|
||||
|
||||
self.decoder_dim = decoder_dim
|
||||
self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
|
||||
self.mask_token = nn.Parameter(torch.randn(decoder_dim))
|
||||
self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
|
||||
@@ -46,7 +49,10 @@ class MAE(nn.Module):
|
||||
# patch to encoder tokens and add positions
|
||||
|
||||
tokens = self.patch_to_emb(patches)
|
||||
tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]
|
||||
if self.encoder.pool == "cls":
|
||||
tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
|
||||
elif self.encoder.pool == "mean":
|
||||
tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)
|
||||
|
||||
# calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked
|
||||
|
||||
@@ -73,7 +79,7 @@ class MAE(nn.Module):
|
||||
|
||||
# reapply decoder position embedding to unmasked tokens
|
||||
|
||||
decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
|
||||
unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
|
||||
|
||||
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
|
||||
|
||||
@@ -81,13 +87,15 @@ class MAE(nn.Module):
|
||||
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
|
||||
|
||||
# concat the masked tokens to the decoder tokens and attend with decoder
|
||||
|
||||
decoder_tokens = torch.cat((mask_tokens, decoder_tokens), dim = 1)
|
||||
|
||||
decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
|
||||
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
|
||||
decoder_tokens[batch_range, masked_indices] = mask_tokens
|
||||
decoded_tokens = self.decoder(decoder_tokens)
|
||||
|
||||
# splice out the mask tokens and project to pixel values
|
||||
|
||||
mask_tokens = decoded_tokens[:, :num_masked]
|
||||
mask_tokens = decoded_tokens[batch_range, masked_indices]
|
||||
pred_pixel_values = self.to_pixels(mask_tokens)
|
||||
|
||||
# calculate reconstruction loss
|
||||
|
||||
@@ -19,20 +19,20 @@ def cast_tuple(val, length = 1):
|
||||
|
||||
# helper classes
|
||||
|
||||
class PreNormResidual(nn.Module):
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x)) + x
|
||||
return self.fn(x) + x
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -103,7 +103,9 @@ def MBConv(
|
||||
nn.Conv2d(dim_in, hidden_dim, 1),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = dim_out),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.GELU(),
|
||||
SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
|
||||
nn.Conv2d(hidden_dim, dim_out, 1),
|
||||
nn.BatchNorm2d(dim_out)
|
||||
@@ -130,6 +132,7 @@ class Attention(nn.Module):
|
||||
self.heads = dim // dim_head
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
|
||||
|
||||
self.attend = nn.Sequential(
|
||||
@@ -158,6 +161,8 @@ class Attention(nn.Module):
|
||||
def forward(self, x):
|
||||
batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
# flatten
|
||||
|
||||
x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
|
||||
@@ -257,13 +262,13 @@ class MaxViT(nn.Module):
|
||||
shrinkage_rate = mbconv_shrinkage_rate
|
||||
),
|
||||
Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
|
||||
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
|
||||
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
|
||||
Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
|
||||
Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
|
||||
Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
|
||||
|
||||
Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
|
||||
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
|
||||
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
|
||||
Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
|
||||
Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
|
||||
Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
|
||||
)
|
||||
|
||||
|
||||
@@ -13,29 +13,20 @@ def conv_1x1_bn(inp, oup):
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
|
||||
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
|
||||
nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -53,6 +44,7 @@ class Attention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim=-1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@@ -64,9 +56,10 @@ class Attention(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(
|
||||
t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
@@ -88,8 +81,8 @@ class Transformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
|
||||
Attention(dim, heads, dim_head, dropout),
|
||||
FeedForward(dim, mlp_dim, dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
@@ -167,11 +160,9 @@ class MobileViTBlock(nn.Module):
|
||||
|
||||
# Global representations
|
||||
_, _, h, w = x.shape
|
||||
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
|
||||
ph=self.ph, pw=self.pw)
|
||||
x = self.transformer(x)
|
||||
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
|
||||
h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
|
||||
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
|
||||
x = self.transformer(x)
|
||||
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
|
||||
|
||||
# Fusion
|
||||
x = self.conv3(x)
|
||||
|
||||
186
vit_pytorch/mp3.py
Normal file
186
vit_pytorch/mp3.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# positional embedding
|
||||
|
||||
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
|
||||
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
||||
|
||||
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
|
||||
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
|
||||
omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# feedforward
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# (cross)attention
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context = None):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
context = self.norm(context) if exists(context) else x
|
||||
|
||||
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
def forward(self, x, context = None):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, context = context) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, num_classes, image_size, patch_size, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.dim = dim
|
||||
self.num_patches = num_patches
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.linear_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
*_, h, w, dtype = *img.shape, img.dtype
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
pe = posemb_sincos_2d(x)
|
||||
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
|
||||
# Masked Position Prediction Pre-Training
|
||||
|
||||
class MP3(nn.Module):
|
||||
def __init__(self, vit: ViT, masking_ratio):
|
||||
super().__init__()
|
||||
self.vit = vit
|
||||
|
||||
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
|
||||
self.masking_ratio = masking_ratio
|
||||
|
||||
dim = vit.dim
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, vit.num_patches)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
device = img.device
|
||||
tokens = self.vit.to_patch_embedding(img)
|
||||
tokens = rearrange(tokens, 'b ... d -> b (...) d')
|
||||
|
||||
batch, num_patches, *_ = tokens.shape
|
||||
|
||||
# Masking
|
||||
num_masked = int(self.masking_ratio * num_patches)
|
||||
rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
|
||||
masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]
|
||||
|
||||
batch_range = torch.arange(batch, device = device)[:, None]
|
||||
tokens_unmasked = tokens[batch_range, unmasked_indices]
|
||||
|
||||
attended_tokens = self.vit.transformer(tokens, tokens_unmasked)
|
||||
logits = rearrange(self.mlp_head(attended_tokens), 'b n d -> (b n) d')
|
||||
|
||||
# Define labels
|
||||
labels = repeat(torch.arange(num_patches, device = device), 'n -> (b n)', b = batch)
|
||||
loss = F.cross_entropy(logits, labels)
|
||||
|
||||
return loss
|
||||
@@ -96,6 +96,9 @@ class MPP(nn.Module):
|
||||
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
|
||||
max_pixel_val, mean, std)
|
||||
|
||||
# extract patching function
|
||||
self.patch_to_emb = nn.Sequential(transformer.to_patch_embedding[1:])
|
||||
|
||||
# output transformation
|
||||
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
|
||||
|
||||
@@ -151,7 +154,7 @@ class MPP(nn.Module):
|
||||
masked_input[bool_mask_replace] = self.mask_token
|
||||
|
||||
# linear embedding of patches
|
||||
masked_input = transformer.to_patch_embedding[-1](masked_input)
|
||||
masked_input = self.patch_to_emb(masked_input)
|
||||
|
||||
# add cls token to input sequence
|
||||
b, n, _ = masked_input.shape
|
||||
|
||||
389
vit_pytorch/na_vit.py
Normal file
389
vit_pytorch/na_vit.py
Normal file
@@ -0,0 +1,389 @@
|
||||
from functools import partial
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, Tensor
|
||||
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def always(val):
|
||||
return lambda *args: val
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
# auto grouping images
|
||||
|
||||
def group_images_by_max_seq_len(
|
||||
images: List[Tensor],
|
||||
patch_size: int,
|
||||
calc_token_dropout = None,
|
||||
max_seq_len = 2048
|
||||
|
||||
) -> List[List[Tensor]]:
|
||||
|
||||
calc_token_dropout = default(calc_token_dropout, always(0.))
|
||||
|
||||
groups = []
|
||||
group = []
|
||||
seq_len = 0
|
||||
|
||||
if isinstance(calc_token_dropout, (float, int)):
|
||||
calc_token_dropout = always(calc_token_dropout)
|
||||
|
||||
for image in images:
|
||||
assert isinstance(image, Tensor)
|
||||
|
||||
image_dims = image.shape[-2:]
|
||||
ph, pw = map(lambda t: t // patch_size, image_dims)
|
||||
|
||||
image_seq_len = (ph * pw)
|
||||
image_seq_len = int(image_seq_len * (1 - calc_token_dropout(*image_dims)))
|
||||
|
||||
assert image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'
|
||||
|
||||
if (seq_len + image_seq_len) > max_seq_len:
|
||||
groups.append(group)
|
||||
group = []
|
||||
seq_len = 0
|
||||
|
||||
group.append(image)
|
||||
seq_len += image_seq_len
|
||||
|
||||
if len(group) > 0:
|
||||
groups.append(group)
|
||||
|
||||
return groups
|
||||
|
||||
# normalization
|
||||
# they use layernorm without bias, something that pytorch does not offer
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.ones(dim))
|
||||
self.register_buffer('beta', torch.zeros(dim))
|
||||
|
||||
def forward(self, x):
|
||||
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
|
||||
|
||||
# they use a query-key normalization that is equivalent to rms norm (no mean-centering, learned gamma), from vit 22B paper
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, heads, dim):
|
||||
super().__init__()
|
||||
self.scale = dim ** 0.5
|
||||
self.gamma = nn.Parameter(torch.ones(heads, 1, dim))
|
||||
|
||||
def forward(self, x):
|
||||
normed = F.normalize(x, dim = -1)
|
||||
return normed * self.scale * self.gamma
|
||||
|
||||
# feedforward
|
||||
|
||||
def FeedForward(dim, hidden_dim, dropout = 0.):
|
||||
return nn.Sequential(
|
||||
LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.norm = LayerNorm(dim)
|
||||
|
||||
self.q_norm = RMSNorm(heads, dim_head)
|
||||
self.k_norm = RMSNorm(heads, dim_head)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim, bias = False),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context = None,
|
||||
mask = None,
|
||||
attn_mask = None
|
||||
):
|
||||
x = self.norm(x)
|
||||
kv_input = default(context, x)
|
||||
|
||||
qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2))
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
|
||||
|
||||
if exists(attn_mask):
|
||||
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
mask = None,
|
||||
attn_mask = None
|
||||
):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, mask = mask, attn_mask = attn_mask) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class NaViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = None):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
# what percent of tokens to dropout
|
||||
# if int or float given, then assume constant dropout prob
|
||||
# otherwise accept a callback that in turn calculates dropout prob from height and width
|
||||
|
||||
self.calc_token_dropout = None
|
||||
|
||||
if callable(token_dropout_prob):
|
||||
self.calc_token_dropout = token_dropout_prob
|
||||
|
||||
elif isinstance(token_dropout_prob, (float, int)):
|
||||
assert 0. < token_dropout_prob < 1.
|
||||
token_dropout_prob = float(token_dropout_prob)
|
||||
self.calc_token_dropout = lambda height, width: token_dropout_prob
|
||||
|
||||
# calculate patching related stuff
|
||||
|
||||
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
|
||||
patch_dim = channels * (patch_size ** 2)
|
||||
|
||||
self.channels = channels
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
|
||||
self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
# final attention pooling queries
|
||||
|
||||
self.attn_pool_queries = nn.Parameter(torch.randn(dim))
|
||||
self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)
|
||||
|
||||
# output to logits
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes, bias = False)
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
|
||||
group_images = False,
|
||||
group_max_seq_len = 2048
|
||||
):
|
||||
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
|
||||
|
||||
arange = partial(torch.arange, device = device)
|
||||
pad_sequence = partial(orig_pad_sequence, batch_first = True)
|
||||
|
||||
# auto pack if specified
|
||||
|
||||
if group_images:
|
||||
batched_images = group_images_by_max_seq_len(
|
||||
batched_images,
|
||||
patch_size = self.patch_size,
|
||||
calc_token_dropout = self.calc_token_dropout,
|
||||
max_seq_len = group_max_seq_len
|
||||
)
|
||||
|
||||
# process images into variable lengthed sequences with attention mask
|
||||
|
||||
num_images = []
|
||||
batched_sequences = []
|
||||
batched_positions = []
|
||||
batched_image_ids = []
|
||||
|
||||
for images in batched_images:
|
||||
num_images.append(len(images))
|
||||
|
||||
sequences = []
|
||||
positions = []
|
||||
image_ids = torch.empty((0,), device = device, dtype = torch.long)
|
||||
|
||||
for image_id, image in enumerate(images):
|
||||
assert image.ndim ==3 and image.shape[0] == c
|
||||
image_dims = image.shape[-2:]
|
||||
assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
|
||||
|
||||
ph, pw = map(lambda dim: dim // p, image_dims)
|
||||
|
||||
pos = torch.stack(torch.meshgrid((
|
||||
arange(ph),
|
||||
arange(pw)
|
||||
), indexing = 'ij'), dim = -1)
|
||||
|
||||
pos = rearrange(pos, 'h w c -> (h w) c')
|
||||
seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
|
||||
|
||||
seq_len = seq.shape[-2]
|
||||
|
||||
if has_token_dropout:
|
||||
token_dropout = self.calc_token_dropout(*image_dims)
|
||||
num_keep = max(1, int(seq_len * (1 - token_dropout)))
|
||||
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
|
||||
|
||||
seq = seq[keep_indices]
|
||||
pos = pos[keep_indices]
|
||||
|
||||
image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
|
||||
sequences.append(seq)
|
||||
positions.append(pos)
|
||||
|
||||
batched_image_ids.append(image_ids)
|
||||
batched_sequences.append(torch.cat(sequences, dim = 0))
|
||||
batched_positions.append(torch.cat(positions, dim = 0))
|
||||
|
||||
# derive key padding mask
|
||||
|
||||
lengths = torch.tensor([seq.shape[-2] for seq in batched_sequences], device = device, dtype = torch.long)
|
||||
max_length = arange(lengths.amax().item())
|
||||
key_pad_mask = rearrange(lengths, 'b -> b 1') <= rearrange(max_length, 'n -> 1 n')
|
||||
|
||||
# derive attention mask, and combine with key padding mask from above
|
||||
|
||||
batched_image_ids = pad_sequence(batched_image_ids)
|
||||
attn_mask = rearrange(batched_image_ids, 'b i -> b 1 i 1') == rearrange(batched_image_ids, 'b j -> b 1 1 j')
|
||||
attn_mask = attn_mask & rearrange(key_pad_mask, 'b j -> b 1 1 j')
|
||||
|
||||
# combine patched images as well as the patched width / height positions for 2d positional embedding
|
||||
|
||||
patches = pad_sequence(batched_sequences)
|
||||
patch_positions = pad_sequence(batched_positions)
|
||||
|
||||
# need to know how many images for final attention pooling
|
||||
|
||||
num_images = torch.tensor(num_images, device = device, dtype = torch.long)
|
||||
|
||||
# to patches
|
||||
|
||||
x = self.to_patch_embedding(patches)
|
||||
|
||||
# factorized 2d absolute positional embedding
|
||||
|
||||
h_indices, w_indices = patch_positions.unbind(dim = -1)
|
||||
|
||||
h_pos = self.pos_embed_height[h_indices]
|
||||
w_pos = self.pos_embed_width[w_indices]
|
||||
|
||||
x = x + h_pos + w_pos
|
||||
|
||||
# embed dropout
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
# attention
|
||||
|
||||
x = self.transformer(x, attn_mask = attn_mask)
|
||||
|
||||
# do attention pooling at the end
|
||||
|
||||
max_queries = num_images.amax().item()
|
||||
|
||||
queries = repeat(self.attn_pool_queries, 'd -> b n d', n = max_queries, b = x.shape[0])
|
||||
|
||||
# attention pool mask
|
||||
|
||||
image_id_arange = arange(max_queries)
|
||||
|
||||
attn_pool_mask = rearrange(image_id_arange, 'i -> i 1') == rearrange(batched_image_ids, 'b j -> b 1 j')
|
||||
|
||||
attn_pool_mask = attn_pool_mask & rearrange(key_pad_mask, 'b j -> b 1 j')
|
||||
|
||||
attn_pool_mask = rearrange(attn_pool_mask, 'b i j -> b 1 i j')
|
||||
|
||||
# attention pool
|
||||
|
||||
x = self.attn_pool(queries, context = x, attn_mask = attn_pool_mask) + queries
|
||||
|
||||
x = rearrange(x, 'b n d -> (b n) d')
|
||||
|
||||
# each batch element may not have same amount of images
|
||||
|
||||
is_images = image_id_arange < rearrange(num_images, 'b -> b 1')
|
||||
is_images = rearrange(is_images, 'b n -> (b n)')
|
||||
|
||||
x = x[is_images]
|
||||
|
||||
# project out to logits
|
||||
|
||||
x = self.to_latent(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||
@@ -24,19 +24,11 @@ class LayerNorm(nn.Module):
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mlp_mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
LayerNorm(dim),
|
||||
nn.Conv2d(dim, dim * mlp_mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -54,6 +46,7 @@ class Attention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
@@ -66,6 +59,8 @@ class Attention(nn.Module):
|
||||
def forward(self, x):
|
||||
b, c, h, w, heads = *x.shape, self.heads
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = 1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)
|
||||
|
||||
@@ -93,8 +88,8 @@ class Transformer(nn.Module):
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
|
||||
Attention(dim, heads = heads, dropout = dropout),
|
||||
FeedForward(dim, mlp_mult, dropout = dropout)
|
||||
]))
|
||||
def forward(self, x):
|
||||
*_, h, w = x.shape
|
||||
@@ -131,7 +126,7 @@ class NesT(nn.Module):
|
||||
fmap_size = image_size // patch_size
|
||||
blocks = 2 ** (num_hierarchies - 1)
|
||||
|
||||
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy
|
||||
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across hierarchy
|
||||
hierarchies = list(reversed(range(num_hierarchies)))
|
||||
mults = [2 ** i for i in reversed(hierarchies)]
|
||||
|
||||
@@ -144,7 +139,9 @@ class NesT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
|
||||
LayerNorm(patch_dim),
|
||||
nn.Conv2d(patch_dim, layer_dims[0], 1),
|
||||
LayerNorm(layer_dims[0])
|
||||
)
|
||||
|
||||
block_repeats = cast_tuple(block_repeats, num_hierarchies)
|
||||
|
||||
@@ -19,18 +19,11 @@ class Parallel(nn.Module):
|
||||
def forward(self, x):
|
||||
return sum([fn(x) for fn in self.fns])
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -49,6 +42,7 @@ class Attention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@@ -60,6 +54,7 @@ class Attention(nn.Module):
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
@@ -77,8 +72,8 @@ class Transformer(nn.Module):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
attn_block = lambda: PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
|
||||
ff_block = lambda: PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
attn_block = lambda: Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)
|
||||
ff_block = lambda: FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
|
||||
@@ -17,18 +17,11 @@ def conv_output_size(image_size, kernel_size, stride, padding = 0):
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -47,6 +40,7 @@ class Attention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
@@ -58,6 +52,8 @@ class Attention(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
@@ -76,8 +72,8 @@ class Transformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
|
||||
@@ -55,14 +55,6 @@ class DepthWiseConv2d(nn.Module):
|
||||
|
||||
# helper classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class SpatialConv(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, kernel, bias = False):
|
||||
super().__init__()
|
||||
@@ -86,6 +78,7 @@ class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0., use_glu = True):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim * 2 if use_glu else hidden_dim),
|
||||
GEGLU() if use_glu else nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -103,6 +96,7 @@ class Attention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@@ -121,6 +115,9 @@ class Attention(nn.Module):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
|
||||
to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {}
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
q = self.to_q(x, **to_q_kwargs)
|
||||
|
||||
qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
|
||||
@@ -162,8 +159,8 @@ class Transformer(nn.Module):
|
||||
self.pos_emb = AxialRotaryEmbedding(dim_head, max_freq = image_size)
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu))
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu)
|
||||
]))
|
||||
def forward(self, x, fmap_dims):
|
||||
pos_emb = self.pos_emb(x[:, 1:])
|
||||
|
||||
@@ -33,15 +33,6 @@ class ChanLayerNorm(nn.Module):
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = ChanLayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x))
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
@@ -65,6 +56,7 @@ class FeedForward(nn.Module):
|
||||
super().__init__()
|
||||
inner_dim = dim * expansion_factor
|
||||
self.net = nn.Sequential(
|
||||
ChanLayerNorm(dim),
|
||||
nn.Conv2d(dim, inner_dim, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -92,6 +84,7 @@ class ScalableSelfAttention(nn.Module):
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.norm = ChanLayerNorm(dim)
|
||||
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
|
||||
self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
|
||||
self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)
|
||||
@@ -104,6 +97,8 @@ class ScalableSelfAttention(nn.Module):
|
||||
def forward(self, x):
|
||||
height, width, heads = *x.shape[-2:], self.heads
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
|
||||
|
||||
# split out heads
|
||||
@@ -145,6 +140,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.norm = ChanLayerNorm(dim)
|
||||
self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)
|
||||
|
||||
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
|
||||
@@ -159,6 +155,8 @@ class InteractiveWindowedSelfAttention(nn.Module):
|
||||
def forward(self, x):
|
||||
height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
wsz_h, wsz_w = default(wsz, height), default(wsz, width)
|
||||
assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'
|
||||
|
||||
@@ -217,11 +215,11 @@ class Transformer(nn.Module):
|
||||
is_first = ind == 0
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
|
||||
ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout),
|
||||
FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
|
||||
PEG(dim) if is_first else None,
|
||||
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
|
||||
PreNorm(dim, InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout))
|
||||
FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
|
||||
InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout)
|
||||
]))
|
||||
|
||||
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()
|
||||
|
||||
@@ -25,15 +25,6 @@ class ChanLayerNorm(nn.Module):
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = ChanLayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x))
|
||||
|
||||
class OverlappingPatchEmbed(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, stride = 2):
|
||||
super().__init__()
|
||||
@@ -59,6 +50,7 @@ class FeedForward(nn.Module):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
ChanLayerNorm(dim),
|
||||
nn.Conv2d(dim, inner_dim, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -85,6 +77,8 @@ class DSSA(nn.Module):
|
||||
self.window_size = window_size
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm = ChanLayerNorm(dim)
|
||||
|
||||
self.attend = nn.Sequential(
|
||||
nn.Softmax(dim = -1),
|
||||
nn.Dropout(dropout)
|
||||
@@ -138,6 +132,8 @@ class DSSA(nn.Module):
|
||||
assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
|
||||
num_windows = (height // wsz) * (width // wsz)
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
# fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
|
||||
|
||||
x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
|
||||
@@ -225,8 +221,8 @@ class Transformer(nn.Module):
|
||||
|
||||
for ind in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = dropout)),
|
||||
DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mult = ff_mult, dropout = dropout),
|
||||
]))
|
||||
|
||||
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()
|
||||
|
||||
@@ -18,8 +18,11 @@ class SimMIM(nn.Module):
|
||||
|
||||
self.encoder = encoder
|
||||
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
|
||||
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
|
||||
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
|
||||
|
||||
self.to_patch = encoder.to_patch_embedding[0]
|
||||
self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])
|
||||
|
||||
pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]
|
||||
|
||||
# simple linear head
|
||||
|
||||
|
||||
176
vit_pytorch/simple_flash_attn_vit.py
Normal file
176
vit_pytorch/simple_flash_attn_vit.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from collections import namedtuple
|
||||
from packaging import version
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# constants
|
||||
|
||||
Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
|
||||
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
||||
|
||||
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
|
||||
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
|
||||
omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# main class
|
||||
|
||||
class Attend(nn.Module):
|
||||
def __init__(self, use_flash = False):
|
||||
super().__init__()
|
||||
self.use_flash = use_flash
|
||||
assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
|
||||
|
||||
# determine efficient attention configs for cuda and cpu
|
||||
|
||||
self.cpu_config = Config(True, True, True)
|
||||
self.cuda_config = None
|
||||
|
||||
if not torch.cuda.is_available() or not use_flash:
|
||||
return
|
||||
|
||||
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
|
||||
|
||||
if device_properties.major == 8 and device_properties.minor == 0:
|
||||
self.cuda_config = Config(True, False, False)
|
||||
else:
|
||||
self.cuda_config = Config(False, True, True)
|
||||
|
||||
def flash_attn(self, q, k, v):
|
||||
config = self.cuda_config if q.is_cuda else self.cpu_config
|
||||
|
||||
# flash attention - https://arxiv.org/abs/2205.14135
|
||||
|
||||
with torch.backends.cuda.sdp_kernel(**config._asdict()):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
return out
|
||||
|
||||
def forward(self, q, k, v):
|
||||
n, device, scale = q.shape[-2], q.device, q.shape[-1] ** -0.5
|
||||
|
||||
if self.use_flash:
|
||||
return self.flash_attn(q, k, v)
|
||||
|
||||
# similarity
|
||||
|
||||
sim = einsum("b h i d, b j d -> b h i j", q, k) * scale
|
||||
|
||||
# attention
|
||||
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = einsum("b h i j, b j d -> b h i d", attn, v)
|
||||
|
||||
return out
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, use_flash = True):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = Attend(use_flash = use_flash)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
out = self.attend(q, k, v)
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_flash):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, use_flash = use_flash),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, use_flash = True):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_flash)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.linear_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
*_, h, w, dtype = *img.shape, img.dtype
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
pe = posemb_sincos_2d(x)
|
||||
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
@@ -9,17 +9,15 @@ from einops.layers.torch import Rearrange
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
|
||||
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
||||
|
||||
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
|
||||
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
|
||||
omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
|
||||
omega = 1. / (temperature ** omega)
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
@@ -66,6 +64,7 @@ class Attention(nn.Module):
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
@@ -76,7 +75,7 @@ class Transformer(nn.Module):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
||||
@@ -86,28 +85,33 @@ class SimpleViT(nn.Module):
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
self.linear_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
*_, h, w, dtype = *img.shape, img.dtype
|
||||
device = img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
pe = posemb_sincos_2d(x)
|
||||
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
||||
x += self.pos_embedding.to(device, dtype=x.dtype)
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
125
vit_pytorch/simple_vit_1d.py
Normal file
125
vit_pytorch/simple_vit_1d.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def posemb_sincos_1d(patches, temperature = 10000, dtype = torch.float32):
|
||||
_, n, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
||||
|
||||
n = torch.arange(n, device = device)
|
||||
assert (dim % 2) == 0, 'feature dimension must be multiple of 2 for sincos emb'
|
||||
omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
n = n.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((n.sin(), n.cos()), dim = 1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
|
||||
assert seq_len % patch_size == 0
|
||||
|
||||
num_patches = seq_len // patch_size
|
||||
patch_dim = channels * patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (n p) -> b n (p c)', p = patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, series):
|
||||
*_, n, dtype = *series.shape, series.dtype
|
||||
|
||||
x = self.to_patch_embedding(series)
|
||||
pe = posemb_sincos_1d(x)
|
||||
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
v = SimpleViT(
|
||||
seq_len = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
time_series = torch.randn(4, 3, 256)
|
||||
logits = v(time_series) # (4, 1000)
|
||||
128
vit_pytorch/simple_vit_3d.py
Normal file
128
vit_pytorch/simple_vit_3d.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
|
||||
_, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
||||
|
||||
z, y, x = torch.meshgrid(
|
||||
torch.arange(f, device = device),
|
||||
torch.arange(h, device = device),
|
||||
torch.arange(w, device = device),
|
||||
indexing = 'ij')
|
||||
|
||||
fourier_dim = dim // 6
|
||||
|
||||
omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1)
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
z = z.flatten()[:, None] * omega[None, :]
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
|
||||
|
||||
pe = F.pad(pe, (0, dim - (fourier_dim * 6))) # pad if feature dimension not cleanly divisible by 6
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(image_patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, video):
|
||||
*_, h, w, dtype = *video.shape, video.dtype
|
||||
|
||||
x = self.to_patch_embedding(video)
|
||||
pe = posemb_sincos_3d(x)
|
||||
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
141
vit_pytorch/simple_vit_with_patch_dropout.py
Normal file
141
vit_pytorch/simple_vit_with_patch_dropout.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
|
||||
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
||||
|
||||
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
|
||||
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
|
||||
omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# patch dropout
|
||||
|
||||
class PatchDropout(nn.Module):
|
||||
def __init__(self, prob):
|
||||
super().__init__()
|
||||
assert 0 <= prob < 1.
|
||||
self.prob = prob
|
||||
|
||||
def forward(self, x):
|
||||
if not self.training or self.prob == 0.:
|
||||
return x
|
||||
|
||||
b, n, _, device = *x.shape, x.device
|
||||
|
||||
batch_indices = torch.arange(b, device = device)
|
||||
batch_indices = rearrange(batch_indices, '... -> ... 1')
|
||||
num_patches_keep = max(1, int(n * (1 - self.prob)))
|
||||
patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices
|
||||
|
||||
return x[batch_indices, patch_indices_keep]
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, patch_dropout = 0.5):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.patch_dropout = PatchDropout(patch_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
*_, h, w, dtype = *img.shape, img.dtype
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
pe = posemb_sincos_2d(x)
|
||||
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
||||
|
||||
x = self.patch_dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
141
vit_pytorch/simple_vit_with_qk_norm.py
Normal file
141
vit_pytorch/simple_vit_with_qk_norm.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# they use a query-key normalization that is equivalent to rms norm (no mean-centering, learned gamma), from vit 22B paper
|
||||
|
||||
# in latest tweet, seem to claim more stable training at higher learning rates
|
||||
# unsure if this has taken off within Brain, or it has some hidden drawback
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, heads, dim):
|
||||
super().__init__()
|
||||
self.scale = dim ** 0.5
|
||||
self.gamma = nn.Parameter(torch.ones(heads, 1, dim) / self.scale)
|
||||
|
||||
def forward(self, x):
|
||||
normed = F.normalize(x, dim = -1)
|
||||
return normed * self.scale * self.gamma
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.q_norm = RMSNorm(heads, dim_head)
|
||||
self.k_norm = RMSNorm(heads, dim_head)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2))
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.LayerNorm(dim)
|
||||
|
||||
def forward(self, img):
|
||||
device = img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
x += self.pos_embedding.to(device, dtype=x.dtype)
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
129
vit_pytorch/simple_vit_with_register_tokens.py
Normal file
129
vit_pytorch/simple_vit_with_register_tokens.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_register_tokens = 4, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
batch, device = img.shape[0], img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
x += self.pos_embedding.to(device, dtype=x.dtype)
|
||||
|
||||
r = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
|
||||
x, ps = pack([x, r], 'b * d')
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x, _ = unpack(x, ps, 'b * d')
|
||||
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
@@ -42,20 +42,11 @@ class LayerNorm(nn.Module):
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
x = self.norm(x)
|
||||
return self.fn(x, **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
LayerNorm(dim),
|
||||
nn.Conv2d(dim, dim * mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -71,7 +62,12 @@ class PatchEmbedding(nn.Module):
|
||||
self.dim = dim
|
||||
self.dim_out = dim_out
|
||||
self.patch_size = patch_size
|
||||
self.proj = nn.Conv2d(patch_size ** 2 * dim, dim_out, 1)
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
LayerNorm(patch_size ** 2 * dim),
|
||||
nn.Conv2d(patch_size ** 2 * dim, dim_out, 1),
|
||||
LayerNorm(dim_out)
|
||||
)
|
||||
|
||||
def forward(self, fmap):
|
||||
p = self.patch_size
|
||||
@@ -94,6 +90,7 @@ class LocalAttention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
|
||||
self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)
|
||||
|
||||
@@ -103,6 +100,8 @@ class LocalAttention(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, fmap):
|
||||
fmap = self.norm(fmap)
|
||||
|
||||
shape, p = fmap.shape, self.patch_size
|
||||
b, n, x, y, h = *shape, self.heads
|
||||
x, y = map(lambda t: t // p, (x, y))
|
||||
@@ -127,6 +126,8 @@ class GlobalAttention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
|
||||
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
|
||||
self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)
|
||||
|
||||
@@ -138,6 +139,8 @@ class GlobalAttention(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
shape = x.shape
|
||||
b, n, _, y, h = *shape, self.heads
|
||||
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
|
||||
@@ -159,10 +162,10 @@ class Transformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Residual(PreNorm(dim, LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size))) if has_local else nn.Identity(),
|
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))) if has_local else nn.Identity(),
|
||||
Residual(PreNorm(dim, GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k))),
|
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout)))
|
||||
Residual(LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size)) if has_local else nn.Identity(),
|
||||
Residual(FeedForward(dim, mlp_mult, dropout = dropout)) if has_local else nn.Identity(),
|
||||
Residual(GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k)),
|
||||
Residual(FeedForward(dim, mlp_mult, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for local_attn, ff1, global_attn, ff2 in self.layers:
|
||||
|
||||
@@ -11,24 +11,18 @@ def pair(t):
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
@@ -41,6 +35,8 @@ class Attention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@@ -52,6 +48,8 @@ class Attention(nn.Module):
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
@@ -67,17 +65,20 @@ class Attention(nn.Module):
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
@@ -93,7 +94,9 @@ class ViT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
@@ -105,16 +108,13 @@ class ViT(nn.Module):
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
130
vit_pytorch/vit_1d.py
Normal file
130
vit_pytorch/vit_1d.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Layernorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
assert (seq_len % patch_size) == 0
|
||||
|
||||
num_patches = seq_len // patch_size
|
||||
patch_dim = channels * patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (n p) -> b n (p c)', p = patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, series):
|
||||
x = self.to_patch_embedding(series)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
|
||||
|
||||
x, ps = pack([cls_tokens, x], 'b * d')
|
||||
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
cls_tokens, _ = unpack(x, ps, 'b * d')
|
||||
|
||||
return self.mlp_head(cls_tokens)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
v = ViT(
|
||||
seq_len = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
time_series = torch.randn(4, 3, 256)
|
||||
logits = v(time_series) # (4, 1000)
|
||||
126
vit_pytorch/vit_3d.py
Normal file
126
vit_pytorch/vit_3d.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(image_patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, video):
|
||||
x = self.to_patch_embedding(video)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
@@ -13,18 +13,11 @@ def pair(t):
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -41,6 +34,7 @@ class LSA(nn.Module):
|
||||
self.heads = heads
|
||||
self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@@ -52,6 +46,7 @@ class LSA(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
@@ -74,8 +69,8 @@ class Transformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
|
||||
147
vit_pytorch/vit_with_patch_dropout.py
Normal file
147
vit_pytorch/vit_with_patch_dropout.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class PatchDropout(nn.Module):
|
||||
def __init__(self, prob):
|
||||
super().__init__()
|
||||
assert 0 <= prob < 1.
|
||||
self.prob = prob
|
||||
|
||||
def forward(self, x):
|
||||
if not self.training or self.prob == 0.:
|
||||
return x
|
||||
|
||||
b, n, _, device = *x.shape, x.device
|
||||
|
||||
batch_indices = torch.arange(b, device = device)
|
||||
batch_indices = rearrange(batch_indices, '... -> ... 1')
|
||||
num_patches_keep = max(1, int(n * (1 - self.prob)))
|
||||
patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices
|
||||
|
||||
return x[batch_indices, patch_indices_keep]
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., patch_dropout = 0.25):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
|
||||
self.patch_dropout = PatchDropout(patch_dropout)
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
x += self.pos_embedding
|
||||
|
||||
x = self.patch_dropout(x)
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
@@ -32,18 +32,11 @@ class PatchMerger(nn.Module):
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -62,6 +55,7 @@ class Attention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@@ -73,6 +67,7 @@ class Attention(nn.Module):
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
@@ -88,6 +83,7 @@ class Attention(nn.Module):
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper
|
||||
@@ -95,8 +91,8 @@ class Transformer(nn.Module):
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for index, (attn, ff) in enumerate(self.layers):
|
||||
@@ -106,7 +102,7 @@ class Transformer(nn.Module):
|
||||
if index == self.patch_merge_layer_index:
|
||||
x = self.patch_merger(x)
|
||||
|
||||
return x
|
||||
return self.norm(x)
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
@@ -121,7 +117,9 @@ class ViT(nn.Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
@@ -131,7 +129,6 @@ class ViT(nn.Module):
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
Reduce('b n d -> b d', 'mean'),
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
|
||||
178
vit_pytorch/vivit.py
Normal file
178
vit_pytorch/vivit.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat, reduce
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
image_patch_size,
|
||||
frames,
|
||||
frame_patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
spatial_depth,
|
||||
temporal_depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
pool = 'cls',
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(image_patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
|
||||
|
||||
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
num_frame_patches = (frames // frame_patch_size)
|
||||
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.global_average_pool = pool == 'mean'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
|
||||
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, video):
|
||||
x = self.to_patch_embedding(video)
|
||||
b, f, n, _ = x.shape
|
||||
|
||||
x = x + self.pos_embedding[:, :f, :n]
|
||||
|
||||
if exists(self.spatial_cls_token):
|
||||
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
|
||||
x = torch.cat((spatial_cls_tokens, x), dim = 2)
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
|
||||
# attend across space
|
||||
|
||||
x = self.spatial_transformer(x)
|
||||
|
||||
x = rearrange(x, '(b f) n d -> b f n d', b = b)
|
||||
|
||||
# excise out the spatial cls tokens or average pool for temporal attention
|
||||
|
||||
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
|
||||
|
||||
# append temporal CLS tokens
|
||||
|
||||
if exists(self.temporal_cls_token):
|
||||
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
|
||||
|
||||
x = torch.cat((temporal_cls_tokens, x), dim = 1)
|
||||
|
||||
# attend across time
|
||||
|
||||
x = self.temporal_transformer(x)
|
||||
|
||||
# excise out temporal cls token or average pool
|
||||
|
||||
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
Reference in New Issue
Block a user