Compare commits
79 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
81661e3966 | ||
|
|
13f8e123bb | ||
|
|
2d4089c88e | ||
|
|
c7bb5fc43f | ||
|
|
946b19be64 | ||
|
|
d93cd84ccd | ||
|
|
5d4c798949 | ||
|
|
d65a742efe | ||
|
|
8c54e01492 | ||
|
|
df656fe7c7 | ||
|
|
4e6a42a0ca | ||
|
|
6d7298d8ad | ||
|
|
9cd56ff29b | ||
|
|
2aae406ce8 | ||
|
|
c2b2db2a54 | ||
|
|
719048d1bd | ||
|
|
d27721a85a | ||
|
|
cb22cbbd19 | ||
|
|
6db20debb4 | ||
|
|
1bae5d3cc5 | ||
|
|
25b384297d | ||
|
|
64a07f50e6 | ||
|
|
126d204ff2 | ||
|
|
c1528acd46 | ||
|
|
1cc0f182a6 | ||
|
|
28eaba6115 | ||
|
|
0082301f9e | ||
|
|
91ed738731 | ||
|
|
1b58daa20a | ||
|
|
f2414b2c1b | ||
|
|
891b92eb74 | ||
|
|
70ba532599 | ||
|
|
e52ac41955 | ||
|
|
0891885485 | ||
|
|
976f489230 | ||
|
|
2c368d1d4e | ||
|
|
b983bbee39 | ||
|
|
86a7302ba6 | ||
|
|
89d3a04b3f | ||
|
|
e7075c64aa | ||
|
|
5ea1559e4c | ||
|
|
f4b0b14094 | ||
|
|
365b4d931e | ||
|
|
79c864d796 | ||
|
|
b45c1356a1 | ||
|
|
ff44d97cb0 | ||
|
|
d35345df6a | ||
|
|
b69b5af34f | ||
|
|
36e32b70fb | ||
|
|
768e47441e | ||
|
|
de0b8ba189 | ||
|
|
6665fc6cd1 | ||
|
|
5b2382f9f0 | ||
|
|
9f8c60651d | ||
|
|
5ae555750f | ||
|
|
c5a461661c | ||
|
|
e212918e2d | ||
|
|
dc57c75478 | ||
|
|
99c44cf5f6 | ||
|
|
5b16e8f809 | ||
|
|
e8f6d72033 | ||
|
|
cb1729af28 | ||
|
|
9e50b2a41e | ||
|
|
06d375351e | ||
|
|
f196d1ec5b | ||
|
|
529044c9b3 | ||
|
|
c30655f3bc | ||
|
|
d2d6de01d3 | ||
|
|
b9eadaef60 | ||
|
|
24ac8350bf | ||
|
|
ca3cef9de0 | ||
|
|
6e1be11517 | ||
|
|
73ed562ce4 | ||
|
|
ff863175a6 | ||
|
|
ca0bdca192 | ||
|
|
1c70271778 | ||
|
|
d7d3febfe3 | ||
|
|
946815164a | ||
|
|
aeed3381c1 |
3
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: [lucidrains]
|
||||
33
.github/workflows/python-test.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pytest
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
python setup.py test
|
||||
1
MANIFEST.in
Normal file
@@ -0,0 +1 @@
|
||||
recursive-include tests *
|
||||
763
README.md
@@ -1,5 +1,45 @@
|
||||
<img src="./images/vit.gif" width="500px"></img>
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Vision Transformer - Pytorch](#vision-transformer---pytorch)
|
||||
- [Install](#install)
|
||||
- [Usage](#usage)
|
||||
- [Parameters](#parameters)
|
||||
- [Distillation](#distillation)
|
||||
- [Deep ViT](#deep-vit)
|
||||
- [CaiT](#cait)
|
||||
- [Token-to-Token ViT](#token-to-token-vit)
|
||||
- [CCT](#cct)
|
||||
- [Cross ViT](#cross-vit)
|
||||
- [PiT](#pit)
|
||||
- [LeViT](#levit)
|
||||
- [CvT](#cvt)
|
||||
- [Twins SVT](#twins-svt)
|
||||
- [CrossFormer](#crossformer)
|
||||
- [RegionViT](#regionvit)
|
||||
- [ScalableViT](#scalablevit)
|
||||
- [SepViT](#sepvit)
|
||||
- [MaxViT](#maxvit)
|
||||
- [NesT](#nest)
|
||||
- [MobileViT](#mobilevit)
|
||||
- [Masked Autoencoder](#masked-autoencoder)
|
||||
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
|
||||
- [Masked Patch Prediction](#masked-patch-prediction)
|
||||
- [Adaptive Token Sampling](#adaptive-token-sampling)
|
||||
- [Patch Merger](#patch-merger)
|
||||
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
|
||||
- [Parallel ViT](#parallel-vit)
|
||||
- [Learnable Memory ViT](#learnable-memory-vit)
|
||||
- [Dino](#dino)
|
||||
- [Accessing Attention](#accessing-attention)
|
||||
- [Research Ideas](#research-ideas)
|
||||
* [Efficient Attention](#efficient-attention)
|
||||
* [Combining with other Transformer improvements](#combining-with-other-transformer-improvements)
|
||||
- [FAQ](#faq)
|
||||
- [Resources](#resources)
|
||||
- [Citations](#citations)
|
||||
|
||||
## Vision Transformer - Pytorch
|
||||
|
||||
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
|
||||
@@ -8,6 +48,8 @@ For a Pytorch implementation with pretrained models, please see Ross Wightman's
|
||||
|
||||
The official Jax repository is <a href="https://github.com/google-research/vision_transformer">here</a>.
|
||||
|
||||
A tensorflow2 translation also exists <a href="https://github.com/taki0112/vit-tensorflow">here</a>, created by research scientist <a href="https://github.com/taki0112">Junho Kim</a>! 🙏
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
@@ -204,6 +246,7 @@ preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## CCT
|
||||
|
||||
<img src="https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2104.05704">CCT</a> proposes compact transformers
|
||||
@@ -215,22 +258,25 @@ You can use this with two methods
|
||||
import torch
|
||||
from vit_pytorch.cct import CCT
|
||||
|
||||
model = CCT(
|
||||
img_size=224,
|
||||
embedding_dim=384,
|
||||
n_conv_layers=2,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
num_layers=14,
|
||||
num_heads=6,
|
||||
mlp_radio=3.,
|
||||
num_classes=1000,
|
||||
positional_embedding='learnable', # ['sine', 'learnable', 'none']
|
||||
)
|
||||
cct = CCT(
|
||||
img_size = (224, 448),
|
||||
embedding_dim = 384,
|
||||
n_conv_layers = 2,
|
||||
kernel_size = 7,
|
||||
stride = 2,
|
||||
padding = 3,
|
||||
pooling_kernel_size = 3,
|
||||
pooling_stride = 2,
|
||||
pooling_padding = 1,
|
||||
num_layers = 14,
|
||||
num_heads = 6,
|
||||
mlp_radio = 3.,
|
||||
num_classes = 1000,
|
||||
positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 448)
|
||||
pred = cct(img) # (1, 1000)
|
||||
```
|
||||
|
||||
Alternatively you can use one of several pre-defined models `[2,4,6,7,8,14,16]`
|
||||
@@ -241,23 +287,23 @@ and the embedding dimension.
|
||||
import torch
|
||||
from vit_pytorch.cct import cct_14
|
||||
|
||||
model = cct_14(
|
||||
img_size=224,
|
||||
n_conv_layers=1,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
num_classes=1000,
|
||||
positional_embedding='learnable', # ['sine', 'learnable', 'none']
|
||||
)
|
||||
cct = cct_14(
|
||||
img_size = 224,
|
||||
n_conv_layers = 1,
|
||||
kernel_size = 7,
|
||||
stride = 2,
|
||||
padding = 3,
|
||||
pooling_kernel_size = 3,
|
||||
pooling_stride = 2,
|
||||
pooling_padding = 1,
|
||||
num_classes = 1000,
|
||||
positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
|
||||
)
|
||||
```
|
||||
|
||||
<a href="https://github.com/SHI-Labs/Compact-Transformers">Official
|
||||
Repository</a> includes links to pretrained model checkpoints.
|
||||
|
||||
|
||||
## Cross ViT
|
||||
|
||||
<img src="./images/cross_vit.png" width="400px"></img>
|
||||
@@ -435,6 +481,153 @@ img = torch.randn(1, 3, 224, 224)
|
||||
pred = model(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## RegionViT
|
||||
|
||||
<img src="./images/regionvit.png" width="400px"></img>
|
||||
|
||||
<img src="./images/regionvit2.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2106.02689">This paper</a> proposes to divide up the feature map into local regions, whereby the local tokens attend to each other. Each local region has its own regional token which then attends to all its local tokens, as well as other regional tokens.
|
||||
|
||||
You can use it as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.regionvit import RegionViT
|
||||
|
||||
model = RegionViT(
|
||||
dim = (64, 128, 256, 512), # tuple of size 4, indicating dimension at each stage
|
||||
depth = (2, 2, 8, 2), # depth of the region to local transformer at each stage
|
||||
window_size = 7, # window size, which should be either 7 or 14
|
||||
num_classes = 1000, # number of output classes
|
||||
tokenize_local_3_conv = False, # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
|
||||
use_peg = False, # whether to use positional generating module. they used this for object detection for a boost in performance
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
|
||||
pred = model(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## CrossFormer
|
||||
|
||||
<img src="./images/crossformer.png" width="400px"></img>
|
||||
|
||||
<img src="./images/crossformer2.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2108.00154">paper</a> beats PVT and Swin using alternating local and global attention. The global attention is done across the windowing dimension for reduced complexity, much like the scheme used for axial attention.
|
||||
|
||||
They also have cross-scale embedding layer, which they shown to be a generic layer that can improve all vision transformers. Dynamic relative positional bias was also formulated to allow the net to generalize to images of greater resolution.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.crossformer import CrossFormer
|
||||
|
||||
model = CrossFormer(
|
||||
num_classes = 1000, # number of output classes
|
||||
dim = (64, 128, 256, 512), # dimension at each stage
|
||||
depth = (2, 2, 8, 2), # depth of transformer at each stage
|
||||
global_window_size = (8, 4, 2, 1), # global window sizes at each stage
|
||||
local_window_size = 7, # local window size (can be customized for each stage, but in paper, held constant at 7 for all stages)
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
|
||||
pred = model(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## ScalableViT
|
||||
|
||||
<img src="./images/scalable-vit-1.png" width="400px"></img>
|
||||
|
||||
<img src="./images/scalable-vit-2.png" width="400px"></img>
|
||||
|
||||
This Bytedance AI <a href="https://arxiv.org/abs/2203.10790">paper</a> proposes the Scalable Self Attention (SSA) and the Interactive Windowed Self Attention (IWSA) modules. The SSA alleviates the computation needed at earlier stages by reducing the key / value feature map by some factor (`reduction_factor`), while modulating the dimension of the queries and keys (`ssa_dim_key`). The IWSA performs self attention within local windows, similar to other vision transformer papers. However, they add a residual of the values, passed through a convolution of kernel size 3, which they named Local Interactive Module (LIM).
|
||||
|
||||
They make the claim in this paper that this scheme outperforms Swin Transformer, and also demonstrate competitive performance against Crossformer.
|
||||
|
||||
You can use it as follows (ex. ScalableViT-S)
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.scalable_vit import ScalableViT
|
||||
|
||||
model = ScalableViT(
|
||||
num_classes = 1000,
|
||||
dim = 64, # starting model dimension. at every stage, dimension is doubled
|
||||
heads = (2, 4, 8, 16), # number of attention heads at each stage
|
||||
depth = (2, 2, 20, 2), # number of transformer blocks at each stage
|
||||
ssa_dim_key = (40, 40, 40, 32), # the dimension of the attention keys (and queries) for SSA. in the paper, they represented this as a scale factor on the base dimension per key (ssa_dim_key / dim_key)
|
||||
reduction_factor = (8, 4, 2, 1), # downsampling of the key / values in SSA. in the paper, this was represented as (reduction_factor ** -2)
|
||||
window_size = (64, 32, None, None), # window size of the IWSA at each stage. None means no windowing needed
|
||||
dropout = 0.1, # attention and feedforward dropout
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
preds = model(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## SepViT
|
||||
|
||||
<img src="./images/sep-vit.png" width="400px"></img>
|
||||
|
||||
Another <a href="https://arxiv.org/abs/2203.15380">Bytedance AI paper</a>, it proposes a depthwise-pointwise self-attention layer that seems largely inspired by mobilenet's depthwise-separable convolution. The most interesting aspect is the reuse of the feature map from the depthwise self-attention stage as the values for the pointwise self-attention, as shown in the diagram above.
|
||||
|
||||
I have decided to include only the version of `SepViT` with this specific self-attention layer, as the grouped attention layers are not remarkable nor novel, and the authors were not clear on how they treated the window tokens for the group self-attention layer. Besides, it seems like with `DSSA` layer alone, they were able to beat Swin.
|
||||
|
||||
ex. SepViT-Lite
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.sep_vit import SepViT
|
||||
|
||||
v = SepViT(
|
||||
num_classes = 1000,
|
||||
dim = 32, # dimensions of first stage, which doubles every stage (32, 64, 128, 256) for SepViT-Lite
|
||||
dim_head = 32, # attention head dimension
|
||||
heads = (1, 2, 4, 8), # number of heads per stage
|
||||
depth = (1, 2, 6, 2), # number of transformer blocks per stage
|
||||
window_size = 7, # window size of DSS Attention block
|
||||
dropout = 0.1 # dropout
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## MaxViT
|
||||
|
||||
<img src="./images/max-vit.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2204.01697">This paper</a> proposes a hybrid convolutional / attention network, using MBConv from the convolution side, and then block / grid axial sparse attention.
|
||||
|
||||
They also claim this specific vision transformer is good for generative models (GANs).
|
||||
|
||||
ex. MaxViT-S
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.max_vit import MaxViT
|
||||
|
||||
v = MaxViT(
|
||||
num_classes = 1000,
|
||||
dim_conv_stem = 64, # dimension of the convolutional stem, would default to dimension of first layer if not specified
|
||||
dim = 96, # dimension of first layer, doubles every layer
|
||||
dim_head = 32, # dimension of attention heads, kept at 32 in paper
|
||||
depth = (2, 2, 5, 2), # number of MaxViT blocks per stage, which consists of MBConv, block-like attention, grid-like attention
|
||||
window_size = 7, # window size for block and grids
|
||||
mbconv_expansion_rate = 4, # expansion rate of MBConv
|
||||
mbconv_shrinkage_rate = 0.25, # shrinkage rate of squeeze-excitation in MBConv
|
||||
dropout = 0.1 # dropout
|
||||
)
|
||||
|
||||
img = torch.randn(2, 3, 224, 224)
|
||||
|
||||
preds = v(img) # (2, 1000)
|
||||
```
|
||||
|
||||
## NesT
|
||||
|
||||
<img src="./images/nest.png" width="400px"></img>
|
||||
@@ -453,14 +646,125 @@ nest = NesT(
|
||||
dim = 96,
|
||||
heads = 3,
|
||||
num_hierarchies = 3, # number of hierarchies
|
||||
block_repeats = (8, 4, 1), # the number of transformer blocks at each heirarchy, starting from the bottom
|
||||
block_repeats = (2, 2, 8), # the number of transformer blocks at each heirarchy, starting from the bottom
|
||||
num_classes = 1000
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
|
||||
pred = nest(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## MobileViT
|
||||
|
||||
<img src="./images/mbvit.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2110.02178">paper</a> introduce MobileViT, a light-weight and general purpose vision transformer for mobile devices. MobileViT presents a different
|
||||
perspective for the global processing of information with transformers.
|
||||
|
||||
You can use it with the following code (ex. mobilevit_xs)
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.mobile_vit import MobileViT
|
||||
|
||||
mbvit_xs = MobileViT(
|
||||
image_size = (256, 256),
|
||||
dims = [96, 120, 144],
|
||||
channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
|
||||
num_classes = 1000
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
pred = mbvit_xs(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Simple Masked Image Modeling
|
||||
|
||||
<img src="./images/simmim.png" width="400px"/>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2111.09886">paper</a> proposes a simple masked image modeling (SimMIM) scheme, using only a linear projection off the masked tokens into pixel space followed by an L1 loss with the pixel values of the masked patches. Results are competitive with other more complicated approaches.
|
||||
|
||||
You can use this as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import ViT
|
||||
from vit_pytorch.simmim import SimMIM
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
mim = SimMIM(
|
||||
encoder = v,
|
||||
masking_ratio = 0.5 # they found 50% to yield the best results
|
||||
)
|
||||
|
||||
images = torch.randn(8, 3, 256, 256)
|
||||
|
||||
loss = mim(images)
|
||||
loss.backward()
|
||||
|
||||
# that's all!
|
||||
# do the above in a for loop many times with a lot of images and your vision transformer will learn
|
||||
|
||||
torch.save(v.state_dict(), './trained-vit.pt')
|
||||
```
|
||||
|
||||
|
||||
## Masked Autoencoder
|
||||
|
||||
<img src="./images/mae.png" width="400px"/>
|
||||
|
||||
A new <a href="https://arxiv.org/abs/2111.06377">Kaiming He paper</a> proposes a simple autoencoder scheme where the vision transformer attends to a set of unmasked patches, and a smaller decoder tries to reconstruct the masked pixel values.
|
||||
|
||||
<a href="https://www.youtube.com/watch?v=LKixq2S2Pz8">DeepReader quick paper review</a>
|
||||
|
||||
<a href="https://www.youtube.com/watch?v=Dp6iICL2dVI">AI Coffeebreak with Letitia</a>
|
||||
|
||||
You can use it with the following code
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import ViT, MAE
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
mae = MAE(
|
||||
encoder = v,
|
||||
masking_ratio = 0.75, # the paper recommended 75% masked patches
|
||||
decoder_dim = 512, # paper showed good results with just 512
|
||||
decoder_depth = 6 # anywhere from 1 to 8
|
||||
)
|
||||
|
||||
images = torch.randn(8, 3, 256, 256)
|
||||
|
||||
loss = mae(images)
|
||||
loss.backward()
|
||||
|
||||
# that's all!
|
||||
# do the above in a for loop many times with a lot of images and your vision transformer will learn
|
||||
|
||||
# save your improved vision transformer
|
||||
torch.save(v.state_dict(), './trained-vit.pt')
|
||||
```
|
||||
|
||||
## Masked Patch Prediction
|
||||
|
||||
Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
|
||||
@@ -507,6 +811,217 @@ for _ in range(100):
|
||||
torch.save(model.state_dict(), './pretrained-net.pt')
|
||||
```
|
||||
|
||||
## Adaptive Token Sampling
|
||||
|
||||
<img src="./images/ats.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2111.15667">paper</a> proposes to use the CLS attention scores, re-weighed by the norms of the value heads, as means to discard unimportant tokens at different layers.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.ats_vit import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
max_tokens_per_depth = (256, 128, 64, 32, 16, 8), # a tuple that denotes the maximum number of tokens that any given layer should have. if the layer has greater than this amount, it will undergo adaptive token sampling
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (4, 1000)
|
||||
|
||||
# you can also get a list of the final sampled patch ids
|
||||
# a value of -1 denotes padding
|
||||
|
||||
preds, token_ids = v(img, return_sampled_token_ids = True) # (4, 1000), (4, <=8)
|
||||
```
|
||||
|
||||
## Patch Merger
|
||||
|
||||
|
||||
<img src="./images/patch_merger.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2202.12015">paper</a> proposes a simple module (Patch Merger) for reducing the number of tokens at any layer of a vision transformer without sacrificing performance.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit_with_patch_merger import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 12,
|
||||
heads = 8,
|
||||
patch_merge_layer = 6, # at which transformer layer to do patch merging
|
||||
patch_merge_num_tokens = 8, # the output number of tokens from the patch merge
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (4, 1000)
|
||||
```
|
||||
|
||||
One can also use the `PatchMerger` module by itself
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit_with_patch_merger import PatchMerger
|
||||
|
||||
merger = PatchMerger(
|
||||
dim = 1024,
|
||||
num_tokens_out = 8 # output number of tokens
|
||||
)
|
||||
|
||||
features = torch.randn(4, 256, 1024) # (batch, num tokens, dimension)
|
||||
|
||||
out = merger(features) # (4, 8, 1024)
|
||||
```
|
||||
|
||||
## Vision Transformer for Small Datasets
|
||||
|
||||
<img src="./images/vit_for_small_datasets.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2112.13492">paper</a> proposes a new image to patch function that incorporates shifts of the image, before normalizing and dividing the image into patches. I have found shifting to be extremely helpful in some other transformers work, so decided to include this for further explorations. It also includes the `LSA` with the learned temperature and masking out of a token's attention to itself.
|
||||
|
||||
You can use as follows:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit_for_small_dataset import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
You can also use the `SPT` from this paper as a standalone module
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit_for_small_dataset import SPT
|
||||
|
||||
spt = SPT(
|
||||
dim = 1024,
|
||||
patch_size = 16,
|
||||
channels = 3
|
||||
)
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
|
||||
tokens = spt(img) # (4, 256, 1024)
|
||||
```
|
||||
|
||||
## Parallel ViT
|
||||
|
||||
<img src="./images/parallel-vit.png" width="350px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2203.09795">paper</a> propose parallelizing multiple attention and feedforward blocks per layer (2 blocks), claiming that it is easier to train without loss of performance.
|
||||
|
||||
You can try this variant as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.parallel_vit import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
num_parallel_branches = 2, # in paper, they claimed 2 was optimal
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (4, 1000)
|
||||
```
|
||||
|
||||
## Learnable Memory ViT
|
||||
|
||||
<img src="./images/learnable-memory-vit.png" width="350px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2203.15243">paper</a> shows that adding learnable memory tokens at each layer of a vision transformer can greatly enhance fine-tuning results (in addition to learnable task specific CLS token and adapter head).
|
||||
|
||||
You can use this with a specially modified `ViT` as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.learnable_memory_vit import ViT, Adapter
|
||||
|
||||
# normal base ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
logits = v(img) # (4, 1000)
|
||||
|
||||
# do your usual training with ViT
|
||||
# ...
|
||||
|
||||
|
||||
# then, to finetune, just pass the ViT into the Adapter class
|
||||
# you can do this for multiple Adapters, as shown below
|
||||
|
||||
adapter1 = Adapter(
|
||||
vit = v,
|
||||
num_classes = 2, # number of output classes for this specific task
|
||||
num_memories_per_layer = 5 # number of learnable memories per layer, 10 was sufficient in paper
|
||||
)
|
||||
|
||||
logits1 = adapter1(img) # (4, 2) - predict 2 classes off frozen ViT backbone with learnable memories and task specific head
|
||||
|
||||
# yet another task to finetune on, this time with 4 classes
|
||||
|
||||
adapter2 = Adapter(
|
||||
vit = v,
|
||||
num_classes = 4,
|
||||
num_memories_per_layer = 10
|
||||
)
|
||||
|
||||
logits2 = adapter2(img) # (4, 4) - predict 4 classes off frozen ViT backbone with learnable memories and task specific head
|
||||
|
||||
```
|
||||
|
||||
## Dino
|
||||
|
||||
<img src="./images/dino.png" width="350px"></img>
|
||||
@@ -602,6 +1117,41 @@ to cleanup the class and the hooks once you have collected enough data
|
||||
v = v.eject() # wrapper is discarded and original ViT instance is returned
|
||||
```
|
||||
|
||||
## Accessing Embeddings
|
||||
|
||||
You can similarly access the embeddings with the `Extractor` wrapper
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
# import Recorder and wrap the ViT
|
||||
|
||||
from vit_pytorch.extractor import Extractor
|
||||
v = Extractor(v)
|
||||
|
||||
# forward pass now returns predictions and the attention maps
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
logits, embeddings = v(img)
|
||||
|
||||
# there is one extra token due to the CLS token
|
||||
|
||||
embeddings # (1, 65, 1024) - (batch x patches x model dim)
|
||||
```
|
||||
|
||||
## Research Ideas
|
||||
|
||||
### Efficient Attention
|
||||
@@ -739,13 +1289,13 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
## Citations
|
||||
```bibtex
|
||||
@article{hassani2021escaping,
|
||||
title = {Escaping the Big Data Paradigm with Compact Transformers},
|
||||
author = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
|
||||
year = 2021,
|
||||
url = {https://arxiv.org/abs/2104.05704},
|
||||
eprint = {2104.05704},
|
||||
archiveprefix = {arXiv},
|
||||
primaryclass = {cs.CV}
|
||||
title = {Escaping the Big Data Paradigm with Compact Transformers},
|
||||
author = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
|
||||
year = 2021,
|
||||
url = {https://arxiv.org/abs/2104.05704},
|
||||
eprint = {2104.05704},
|
||||
archiveprefix = {arXiv},
|
||||
primaryclass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -773,10 +1323,10 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```bibtex
|
||||
@misc{yuan2021tokenstotoken,
|
||||
title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
|
||||
author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
|
||||
year = {2021},
|
||||
eprint = {2101.11986},
|
||||
title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
|
||||
author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
|
||||
year = {2021},
|
||||
eprint = {2101.11986},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
@@ -892,6 +1442,28 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{chen2021regionvit,
|
||||
title = {RegionViT: Regional-to-Local Attention for Vision Transformers},
|
||||
author = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan},
|
||||
year = {2021},
|
||||
eprint = {2106.02689},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{wang2021crossformer,
|
||||
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
|
||||
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
|
||||
year = {2021},
|
||||
eprint = {2108.00154},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{caron2021emerging,
|
||||
title = {Emerging Properties in Self-Supervised Vision Transformers},
|
||||
@@ -903,6 +1475,115 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{he2021masked,
|
||||
title = {Masked Autoencoders Are Scalable Vision Learners},
|
||||
author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
|
||||
year = {2021},
|
||||
eprint = {2111.06377},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{xie2021simmim,
|
||||
title = {SimMIM: A Simple Framework for Masked Image Modeling},
|
||||
author = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
|
||||
year = {2021},
|
||||
eprint = {2111.09886},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{fayyaz2021ats,
|
||||
title = {ATS: Adaptive Token Sampling For Efficient Vision Transformers},
|
||||
author = {Mohsen Fayyaz and Soroush Abbasi Kouhpayegani and Farnoush Rezaei Jafari and Eric Sommerlade and Hamid Reza Vaezi Joze and Hamed Pirsiavash and Juergen Gall},
|
||||
year = {2021},
|
||||
eprint = {2111.15667},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{mehta2021mobilevit,
|
||||
title = {MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
|
||||
author = {Sachin Mehta and Mohammad Rastegari},
|
||||
year = {2021},
|
||||
eprint = {2110.02178},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{lee2021vision,
|
||||
title = {Vision Transformer for Small-Size Datasets},
|
||||
author = {Seung Hoon Lee and Seunghyun Lee and Byung Cheol Song},
|
||||
year = {2021},
|
||||
eprint = {2112.13492},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{renggli2022learning,
|
||||
title = {Learning to Merge Tokens in Vision Transformers},
|
||||
author = {Cedric Renggli and André Susano Pinto and Neil Houlsby and Basil Mustafa and Joan Puigcerver and Carlos Riquelme},
|
||||
year = {2022},
|
||||
eprint = {2202.12015},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{yang2022scalablevit,
|
||||
title = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer},
|
||||
author = {Rui Yang and Hailong Ma and Jie Wu and Yansong Tang and Xuefeng Xiao and Min Zheng and Xiu Li},
|
||||
year = {2022},
|
||||
eprint = {2203.10790},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Touvron2022ThreeTE,
|
||||
title = {Three things everyone should know about Vision Transformers},
|
||||
author = {Hugo Touvron and Matthieu Cord and Alaaeldin El-Nouby and Jakob Verbeek and Herv'e J'egou},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Sandler2022FinetuningIT,
|
||||
title = {Fine-tuning Image Transformers using Learnable Memory},
|
||||
author = {Mark Sandler and Andrey Zhmoginov and Max Vladymyrov and Andrew Jackson},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Li2022SepViTSV,
|
||||
title = {SepViT: Separable Vision Transformer},
|
||||
author = {Wei Li and Xing Wang and Xin Xia and Jie Wu and Xuefeng Xiao and Minghang Zheng and Shiping Wen},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Tu2022MaxViTMV,
|
||||
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
|
||||
@@ -364,9 +364,8 @@
|
||||
"\n",
|
||||
"val_transforms = transforms.Compose(\n",
|
||||
" [\n",
|
||||
" transforms.Resize((224, 224)),\n",
|
||||
" transforms.RandomResizedCrop(224),\n",
|
||||
" transforms.RandomHorizontalFlip(),\n",
|
||||
" transforms.Resize(256),\n",
|
||||
" transforms.CenterCrop(224),\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
@@ -374,9 +373,8 @@
|
||||
"\n",
|
||||
"test_transforms = transforms.Compose(\n",
|
||||
" [\n",
|
||||
" transforms.Resize((224, 224)),\n",
|
||||
" transforms.RandomResizedCrop(224),\n",
|
||||
" transforms.RandomHorizontalFlip(),\n",
|
||||
" transforms.Resize(256),\n",
|
||||
" transforms.CenterCrop(224),\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" ]\n",
|
||||
")\n"
|
||||
@@ -6250,4 +6248,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
||||
}
|
||||
BIN
images/ats.png
Normal file
|
After Width: | Height: | Size: 198 KiB |
BIN
images/crossformer.png
Normal file
|
After Width: | Height: | Size: 169 KiB |
BIN
images/crossformer2.png
Normal file
|
After Width: | Height: | Size: 237 KiB |
BIN
images/learnable-memory-vit.png
Normal file
|
After Width: | Height: | Size: 108 KiB |
BIN
images/mae.png
Normal file
|
After Width: | Height: | Size: 198 KiB |
BIN
images/max-vit.png
Normal file
|
After Width: | Height: | Size: 133 KiB |
BIN
images/mbvit.png
Normal file
|
After Width: | Height: | Size: 206 KiB |
BIN
images/parallel-vit.png
Normal file
|
After Width: | Height: | Size: 14 KiB |
BIN
images/patch_merger.png
Normal file
|
After Width: | Height: | Size: 54 KiB |
BIN
images/regionvit.png
Normal file
|
After Width: | Height: | Size: 94 KiB |
BIN
images/regionvit2.png
Normal file
|
After Width: | Height: | Size: 55 KiB |
BIN
images/scalable-vit-1.png
Normal file
|
After Width: | Height: | Size: 79 KiB |
BIN
images/scalable-vit-2.png
Normal file
|
After Width: | Height: | Size: 62 KiB |
BIN
images/sep-vit.png
Normal file
|
After Width: | Height: | Size: 142 KiB |
BIN
images/simmim.png
Normal file
|
After Width: | Height: | Size: 365 KiB |
BIN
images/vit_for_small_datasets.png
Normal file
|
After Width: | Height: | Size: 80 KiB |
12
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.20.4',
|
||||
version = '0.33.2',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
@@ -15,10 +15,16 @@ setup(
|
||||
'image recognition'
|
||||
],
|
||||
install_requires=[
|
||||
'einops>=0.3',
|
||||
'torch>=1.6',
|
||||
'einops>=0.4.1',
|
||||
'torch>=1.10',
|
||||
'torchvision'
|
||||
],
|
||||
setup_requires=[
|
||||
'pytest-runner',
|
||||
],
|
||||
tests_require=[
|
||||
'pytest'
|
||||
],
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
'Intended Audience :: Developers',
|
||||
|
||||
20
tests/test.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
from vit_pytorch import ViT
|
||||
|
||||
def test():
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
preds = v(img)
|
||||
assert preds.shape == (1, 1000), 'correct logits outputted'
|
||||
@@ -1,2 +1,3 @@
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.mae import MAE
|
||||
from vit_pytorch.dino import Dino
|
||||
|
||||
265
vit_pytorch/ats_vit.py
Normal file
@@ -0,0 +1,265 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch import nn, einsum
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# adaptive token sampling functions and classes
|
||||
|
||||
def log(t, eps = 1e-6):
|
||||
return torch.log(t + eps)
|
||||
|
||||
def sample_gumbel(shape, device, dtype, eps = 1e-6):
|
||||
u = torch.empty(shape, device = device, dtype = dtype).uniform_(0, 1)
|
||||
return -log(-log(u, eps), eps)
|
||||
|
||||
def batched_index_select(values, indices, dim = 1):
|
||||
value_dims = values.shape[(dim + 1):]
|
||||
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
|
||||
indices = indices[(..., *((None,) * len(value_dims)))]
|
||||
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
|
||||
value_expand_len = len(indices_shape) - (dim + 1)
|
||||
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
|
||||
|
||||
value_expand_shape = [-1] * len(values.shape)
|
||||
expand_slice = slice(dim, (dim + value_expand_len))
|
||||
value_expand_shape[expand_slice] = indices.shape[expand_slice]
|
||||
values = values.expand(*value_expand_shape)
|
||||
|
||||
dim += value_expand_len
|
||||
return values.gather(dim, indices)
|
||||
|
||||
class AdaptiveTokenSampling(nn.Module):
|
||||
def __init__(self, output_num_tokens, eps = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.output_num_tokens = output_num_tokens
|
||||
|
||||
def forward(self, attn, value, mask):
|
||||
heads, output_num_tokens, eps, device, dtype = attn.shape[1], self.output_num_tokens, self.eps, attn.device, attn.dtype
|
||||
|
||||
# first get the attention values for CLS token to all other tokens
|
||||
|
||||
cls_attn = attn[..., 0, 1:]
|
||||
|
||||
# calculate the norms of the values, for weighting the scores, as described in the paper
|
||||
|
||||
value_norms = value[..., 1:, :].norm(dim = -1)
|
||||
|
||||
# weigh the attention scores by the norm of the values, sum across all heads
|
||||
|
||||
cls_attn = einsum('b h n, b h n -> b n', cls_attn, value_norms)
|
||||
|
||||
# normalize to 1
|
||||
|
||||
normed_cls_attn = cls_attn / (cls_attn.sum(dim = -1, keepdim = True) + eps)
|
||||
|
||||
# instead of using inverse transform sampling, going to invert the softmax and use gumbel-max sampling instead
|
||||
|
||||
pseudo_logits = log(normed_cls_attn)
|
||||
|
||||
# mask out pseudo logits for gumbel-max sampling
|
||||
|
||||
mask_without_cls = mask[:, 1:]
|
||||
mask_value = -torch.finfo(attn.dtype).max / 2
|
||||
pseudo_logits = pseudo_logits.masked_fill(~mask_without_cls, mask_value)
|
||||
|
||||
# expand k times, k being the adaptive sampling number
|
||||
|
||||
pseudo_logits = repeat(pseudo_logits, 'b n -> b k n', k = output_num_tokens)
|
||||
pseudo_logits = pseudo_logits + sample_gumbel(pseudo_logits.shape, device = device, dtype = dtype)
|
||||
|
||||
# gumble-max and add one to reserve 0 for padding / mask
|
||||
|
||||
sampled_token_ids = pseudo_logits.argmax(dim = -1) + 1
|
||||
|
||||
# calculate unique using torch.unique and then pad the sequence from the right
|
||||
|
||||
unique_sampled_token_ids_list = [torch.unique(t, sorted = True) for t in torch.unbind(sampled_token_ids)]
|
||||
unique_sampled_token_ids = pad_sequence(unique_sampled_token_ids_list, batch_first = True)
|
||||
|
||||
# calculate the new mask, based on the padding
|
||||
|
||||
new_mask = unique_sampled_token_ids != 0
|
||||
|
||||
# CLS token never gets masked out (gets a value of True)
|
||||
|
||||
new_mask = F.pad(new_mask, (1, 0), value = True)
|
||||
|
||||
# prepend a 0 token id to keep the CLS attention scores
|
||||
|
||||
unique_sampled_token_ids = F.pad(unique_sampled_token_ids, (1, 0), value = 0)
|
||||
expanded_unique_sampled_token_ids = repeat(unique_sampled_token_ids, 'b n -> b h n', h = heads)
|
||||
|
||||
# gather the new attention scores
|
||||
|
||||
new_attn = batched_index_select(attn, expanded_unique_sampled_token_ids, dim = 2)
|
||||
|
||||
# return the sampled attention scores, new mask (denoting padding), as well as the sampled token indices (for the residual)
|
||||
return new_attn, new_mask, unique_sampled_token_ids
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_tokens = None):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.output_num_tokens = output_num_tokens
|
||||
self.ats = AdaptiveTokenSampling(output_num_tokens) if exists(output_num_tokens) else None
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, *, mask):
|
||||
num_tokens = x.shape[1]
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
dots_mask = rearrange(mask, 'b i -> b 1 i 1') * rearrange(mask, 'b j -> b 1 1 j')
|
||||
mask_value = -torch.finfo(dots.dtype).max
|
||||
dots = dots.masked_fill(~dots_mask, mask_value)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
sampled_token_ids = None
|
||||
|
||||
# if adaptive token sampling is enabled
|
||||
# and number of tokens is greater than the number of output tokens
|
||||
if exists(self.output_num_tokens) and (num_tokens - 1) > self.output_num_tokens:
|
||||
attn, mask, sampled_token_ids = self.ats(attn, v, mask = mask)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
return self.to_out(out), mask, sampled_token_ids
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
assert len(max_tokens_per_depth) == depth, 'max_tokens_per_depth must be a tuple of length that is equal to the depth of the transformer'
|
||||
assert sorted(max_tokens_per_depth, reverse = True) == list(max_tokens_per_depth), 'max_tokens_per_depth must be in decreasing order'
|
||||
assert min(max_tokens_per_depth) > 0, 'max_tokens_per_depth must have at least 1 token at any layer'
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
b, n, device = *x.shape[:2], x.device
|
||||
|
||||
# use mask to keep track of the paddings when sampling tokens
|
||||
# as the duplicates (when sampling) are just removed, as mentioned in the paper
|
||||
mask = torch.ones((b, n), device = device, dtype = torch.bool)
|
||||
|
||||
token_ids = torch.arange(n, device = device)
|
||||
token_ids = repeat(token_ids, 'n -> b n', b = b)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
attn_out, mask, sampled_token_ids = attn(x, mask = mask)
|
||||
|
||||
# when token sampling, one needs to then gather the residual tokens with the sampled token ids
|
||||
if exists(sampled_token_ids):
|
||||
x = batched_index_select(x, sampled_token_ids, dim = 1)
|
||||
token_ids = batched_index_select(token_ids, sampled_token_ids, dim = 1)
|
||||
|
||||
x = x + attn_out
|
||||
|
||||
x = ff(x) + x
|
||||
|
||||
return x, token_ids
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, max_tokens_per_depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img, return_sampled_token_ids = False):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x, token_ids = self.transformer(x)
|
||||
|
||||
logits = self.mlp_head(x[:, 0])
|
||||
|
||||
if return_sampled_token_ids:
|
||||
# remove CLS token and decrement by 1 to make -1 the padding
|
||||
token_ids = token_ids[:, 1:] - 1
|
||||
return logits, token_ids
|
||||
|
||||
return logits
|
||||
@@ -76,6 +76,7 @@ class Attention(nn.Module):
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
|
||||
self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))
|
||||
@@ -96,7 +97,10 @@ class Attention(nn.Module):
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
|
||||
@@ -2,7 +2,13 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Pre-defined CCT Models
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# CCT Models
|
||||
|
||||
__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']
|
||||
|
||||
|
||||
@@ -55,8 +61,8 @@ def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
|
||||
padding=padding,
|
||||
*args, **kwargs)
|
||||
|
||||
# modules
|
||||
|
||||
# Modules
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
|
||||
super().__init__()
|
||||
@@ -308,6 +314,7 @@ class CCT(nn.Module):
|
||||
pooling_padding=1,
|
||||
*args, **kwargs):
|
||||
super(CCT, self).__init__()
|
||||
img_height, img_width = pair(img_size)
|
||||
|
||||
self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
|
||||
n_output_channels=embedding_dim,
|
||||
@@ -324,8 +331,8 @@ class CCT(nn.Module):
|
||||
|
||||
self.classifier = TransformerClassifier(
|
||||
sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
|
||||
height=img_size,
|
||||
width=img_size),
|
||||
height=img_height,
|
||||
width=img_width),
|
||||
embedding_dim=embedding_dim,
|
||||
seq_pool=True,
|
||||
dropout_rate=0.,
|
||||
@@ -336,4 +343,3 @@ class CCT(nn.Module):
|
||||
def forward(self, x):
|
||||
x = self.tokenizer(x)
|
||||
return self.classifier(x)
|
||||
|
||||
|
||||
@@ -48,6 +48,8 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
@@ -69,6 +71,7 @@ class Attention(nn.Module):
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
267
vit_pytorch/crossformer.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
import torch.nn.functional as F
|
||||
|
||||
# helpers
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
# cross embed layer
|
||||
|
||||
class CrossEmbedLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in,
|
||||
dim_out,
|
||||
kernel_sizes,
|
||||
stride = 2
|
||||
):
|
||||
super().__init__()
|
||||
kernel_sizes = sorted(kernel_sizes)
|
||||
num_scales = len(kernel_sizes)
|
||||
|
||||
# calculate the dimension at each scale
|
||||
dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
|
||||
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
|
||||
|
||||
self.convs = nn.ModuleList([])
|
||||
for kernel, dim_scale in zip(kernel_sizes, dim_scales):
|
||||
self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
|
||||
|
||||
def forward(self, x):
|
||||
fmaps = tuple(map(lambda conv: conv(x), self.convs))
|
||||
return torch.cat(fmaps, dim = 1)
|
||||
|
||||
# dynamic positional bias
|
||||
|
||||
def DynamicPositionBias(dim):
|
||||
return nn.Sequential(
|
||||
nn.Linear(2, dim),
|
||||
nn.LayerNorm(dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(dim, 1),
|
||||
Rearrange('... () -> ...')
|
||||
)
|
||||
|
||||
# transformer classes
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0.):
|
||||
return nn.Sequential(
|
||||
LayerNorm(dim),
|
||||
nn.Conv2d(dim, dim * mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(dim * mult, dim, 1)
|
||||
)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
attn_type,
|
||||
window_size,
|
||||
dim_head = 32,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
assert attn_type in {'short', 'long'}, 'attention type must be one of local or distant'
|
||||
heads = dim // dim_head
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.attn_type = attn_type
|
||||
self.window_size = window_size
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(inner_dim, dim, 1)
|
||||
|
||||
# positions
|
||||
|
||||
self.dpb = DynamicPositionBias(dim // 4)
|
||||
|
||||
# calculate and store indices for retrieving bias
|
||||
|
||||
pos = torch.arange(window_size)
|
||||
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
|
||||
grid = rearrange(grid, 'c i j -> (i j) c')
|
||||
rel_pos = grid[:, None] - grid[None, :]
|
||||
rel_pos += window_size - 1
|
||||
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
|
||||
|
||||
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
|
||||
|
||||
def forward(self, x):
|
||||
*_, height, width, heads, wsz, device = *x.shape, self.heads, self.window_size, x.device
|
||||
|
||||
# prenorm
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
# rearrange for short or long distance attention
|
||||
|
||||
if self.attn_type == 'short':
|
||||
x = rearrange(x, 'b d (h s1) (w s2) -> (b h w) d s1 s2', s1 = wsz, s2 = wsz)
|
||||
elif self.attn_type == 'long':
|
||||
x = rearrange(x, 'b d (l1 h) (l2 w) -> (b h w) d l1 l2', l1 = wsz, l2 = wsz)
|
||||
|
||||
# queries / keys / values
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
|
||||
|
||||
# split heads
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), (q, k, v))
|
||||
q = q * self.scale
|
||||
|
||||
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
|
||||
# add dynamic positional bias
|
||||
|
||||
pos = torch.arange(-wsz, wsz + 1, device = device)
|
||||
rel_pos = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
|
||||
rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
|
||||
biases = self.dpb(rel_pos.float())
|
||||
rel_pos_bias = biases[self.rel_pos_indices]
|
||||
|
||||
sim = sim + rel_pos_bias
|
||||
|
||||
# attend
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = wsz, y = wsz)
|
||||
out = self.to_out(out)
|
||||
|
||||
# rearrange back for long or short distance attention
|
||||
|
||||
if self.attn_type == 'short':
|
||||
out = rearrange(out, '(b h w) d s1 s2 -> b d (h s1) (w s2)', h = height // wsz, w = width // wsz)
|
||||
elif self.attn_type == 'long':
|
||||
out = rearrange(out, '(b h w) d l1 l2 -> b d (l1 h) (l2 w)', h = height // wsz, w = width // wsz)
|
||||
|
||||
return out
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
local_window_size,
|
||||
global_window_size,
|
||||
depth = 4,
|
||||
dim_head = 32,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, attn_type = 'short', window_size = local_window_size, dim_head = dim_head, dropout = attn_dropout),
|
||||
FeedForward(dim, dropout = ff_dropout),
|
||||
Attention(dim, attn_type = 'long', window_size = global_window_size, dim_head = dim_head, dropout = attn_dropout),
|
||||
FeedForward(dim, dropout = ff_dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for short_attn, short_ff, long_attn, long_ff in self.layers:
|
||||
x = short_attn(x) + x
|
||||
x = short_ff(x) + x
|
||||
x = long_attn(x) + x
|
||||
x = long_ff(x) + x
|
||||
|
||||
return x
|
||||
|
||||
# classes
|
||||
|
||||
class CrossFormer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim = (64, 128, 256, 512),
|
||||
depth = (2, 2, 8, 2),
|
||||
global_window_size = (8, 4, 2, 1),
|
||||
local_window_size = 7,
|
||||
cross_embed_kernel_sizes = ((4, 8, 16, 32), (2, 4), (2, 4), (2, 4)),
|
||||
cross_embed_strides = (4, 2, 2, 2),
|
||||
num_classes = 1000,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.,
|
||||
channels = 3
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
dim = cast_tuple(dim, 4)
|
||||
depth = cast_tuple(depth, 4)
|
||||
global_window_size = cast_tuple(global_window_size, 4)
|
||||
local_window_size = cast_tuple(local_window_size, 4)
|
||||
cross_embed_kernel_sizes = cast_tuple(cross_embed_kernel_sizes, 4)
|
||||
cross_embed_strides = cast_tuple(cross_embed_strides, 4)
|
||||
|
||||
assert len(dim) == 4
|
||||
assert len(depth) == 4
|
||||
assert len(global_window_size) == 4
|
||||
assert len(local_window_size) == 4
|
||||
assert len(cross_embed_kernel_sizes) == 4
|
||||
assert len(cross_embed_strides) == 4
|
||||
|
||||
# dimensions
|
||||
|
||||
last_dim = dim[-1]
|
||||
dims = [channels, *dim]
|
||||
dim_in_and_out = tuple(zip(dims[:-1], dims[1:]))
|
||||
|
||||
# layers
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
for (dim_in, dim_out), layers, global_wsz, local_wsz, cel_kernel_sizes, cel_stride in zip(dim_in_and_out, depth, global_window_size, local_window_size, cross_embed_kernel_sizes, cross_embed_strides):
|
||||
self.layers.append(nn.ModuleList([
|
||||
CrossEmbedLayer(dim_in, dim_out, cel_kernel_sizes, stride = cel_stride),
|
||||
Transformer(dim_out, local_window_size = local_wsz, global_window_size = global_wsz, depth = layers, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
|
||||
]))
|
||||
|
||||
# final logits
|
||||
|
||||
self.to_logits = nn.Sequential(
|
||||
Reduce('b c h w -> b c', 'mean'),
|
||||
nn.Linear(last_dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for cel, transformer in self.layers:
|
||||
x = cel(x)
|
||||
x = transformer(x)
|
||||
|
||||
return self.to_logits(x)
|
||||
@@ -30,9 +30,9 @@ class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (std + self.eps) * self.g + self.b
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
@@ -76,6 +76,7 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
|
||||
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)
|
||||
@@ -94,6 +95,7 @@ class Attention(nn.Module):
|
||||
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
|
||||
|
||||
@@ -42,6 +42,8 @@ class Attention(nn.Module):
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
|
||||
|
||||
self.reattn_norm = nn.Sequential(
|
||||
@@ -64,6 +66,7 @@ class Attention(nn.Module):
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
attn = dots.softmax(dim=-1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# re-attention
|
||||
|
||||
|
||||
@@ -3,12 +3,16 @@ from torch import nn
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
image_size_h, image_size_w = pair(image_size)
|
||||
assert image_size_h % patch_size == 0 and image_size_w % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
num_patches = (image_size_h // patch_size) * (image_size_w // patch_size)
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
|
||||
70
vit_pytorch/extractor.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
class Extractor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vit,
|
||||
device = None,
|
||||
layer_name = 'transformer',
|
||||
layer_save_input = False,
|
||||
return_embeddings_only = False
|
||||
):
|
||||
super().__init__()
|
||||
self.vit = vit
|
||||
|
||||
self.data = None
|
||||
self.latents = None
|
||||
self.hooks = []
|
||||
self.hook_registered = False
|
||||
self.ejected = False
|
||||
self.device = device
|
||||
|
||||
self.layer_name = layer_name
|
||||
self.layer_save_input = layer_save_input # whether to save input or output of layer
|
||||
self.return_embeddings_only = return_embeddings_only
|
||||
|
||||
def _hook(self, _, inputs, output):
|
||||
tensor_to_save = inputs if self.layer_save_input else output
|
||||
self.latents = tensor_to_save.clone().detach()
|
||||
|
||||
def _register_hook(self):
|
||||
assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
|
||||
layer = getattr(self.vit, self.layer_name)
|
||||
handle = layer.register_forward_hook(self._hook)
|
||||
self.hooks.append(handle)
|
||||
self.hook_registered = True
|
||||
|
||||
def eject(self):
|
||||
self.ejected = True
|
||||
for hook in self.hooks:
|
||||
hook.remove()
|
||||
self.hooks.clear()
|
||||
return self.vit
|
||||
|
||||
def clear(self):
|
||||
del self.latents
|
||||
self.latents = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img,
|
||||
return_embeddings_only = False
|
||||
):
|
||||
assert not self.ejected, 'extractor has been ejected, cannot be used anymore'
|
||||
self.clear()
|
||||
if not self.hook_registered:
|
||||
self._register_hook()
|
||||
|
||||
pred = self.vit(img)
|
||||
|
||||
target_device = self.device if exists(self.device) else img.device
|
||||
latents = self.latents.to(target_device)
|
||||
|
||||
if return_embeddings_only or self.return_embeddings_only:
|
||||
return latents
|
||||
|
||||
return pred, latents
|
||||
216
vit_pytorch/learnable_memory_vit.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# controlling freezing of layers
|
||||
|
||||
def set_module_requires_grad_(module, requires_grad):
|
||||
for param in module.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
|
||||
def freeze_all_layers_(module):
|
||||
set_module_requires_grad_(module, False)
|
||||
|
||||
def unfreeze_all_layers_(module):
|
||||
set_module_requires_grad_(module, True)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, attn_mask = None, memories = None):
|
||||
x = self.norm(x)
|
||||
|
||||
x_kv = x # input for key / values projection
|
||||
|
||||
if exists(memories):
|
||||
# add memories to key / values if it is passed in
|
||||
memories = repeat(memories, 'n d -> b n d', b = x.shape[0]) if memories.ndim == 2 else memories
|
||||
x_kv = torch.cat((x_kv, memories), dim = 1)
|
||||
|
||||
qkv = (self.to_q(x), *self.to_kv(x_kv).chunk(2, dim = -1))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
if exists(attn_mask):
|
||||
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x, attn_mask = None, memories = None):
|
||||
for ind, (attn, ff) in enumerate(self.layers):
|
||||
layer_memories = memories[ind] if exists(memories) else None
|
||||
|
||||
x = attn(x, attn_mask = attn_mask, memories = layer_memories) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def img_to_tokens(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = x.shape[0])
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
|
||||
x += self.pos_embedding
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
def forward(self, img):
|
||||
x = self.img_to_tokens(img)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
cls_tokens = x[:, 0]
|
||||
return self.mlp_head(cls_tokens)
|
||||
|
||||
# adapter with learnable memories per layer, memory CLS token, and learnable adapter head
|
||||
|
||||
class Adapter(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vit,
|
||||
num_memories_per_layer = 10,
|
||||
num_classes = 2,
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(vit, ViT)
|
||||
|
||||
# extract some model variables needed
|
||||
|
||||
dim = vit.cls_token.shape[-1]
|
||||
layers = len(vit.transformer.layers)
|
||||
num_patches = vit.pos_embedding.shape[-2]
|
||||
|
||||
self.vit = vit
|
||||
|
||||
# freeze ViT backbone - only memories will be finetuned
|
||||
|
||||
freeze_all_layers_(vit)
|
||||
|
||||
# learnable parameters
|
||||
|
||||
self.memory_cls_token = nn.Parameter(torch.randn(dim))
|
||||
self.memories_per_layer = nn.Parameter(torch.randn(layers, num_memories_per_layer, dim))
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
# specialized attention mask to preserve the output of the original ViT
|
||||
# it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa
|
||||
|
||||
attn_mask = torch.ones((num_patches, num_patches), dtype = torch.bool)
|
||||
attn_mask = F.pad(attn_mask, (1, num_memories_per_layer), value = False) # main tokens cannot attend to learnable memories per layer
|
||||
attn_mask = F.pad(attn_mask, (0, 0, 1, 0), value = True) # memory CLS token can attend to everything
|
||||
self.register_buffer('attn_mask', attn_mask)
|
||||
|
||||
def forward(self, img):
|
||||
b = img.shape[0]
|
||||
|
||||
tokens = self.vit.img_to_tokens(img)
|
||||
|
||||
# add task specific memory tokens
|
||||
|
||||
memory_cls_tokens = repeat(self.memory_cls_token, 'd -> b 1 d', b = b)
|
||||
tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)
|
||||
|
||||
# pass memories along with image tokens through transformer for attending
|
||||
|
||||
out = self.vit.transformer(tokens, memories = self.memories_per_layer, attn_mask = self.attn_mask)
|
||||
|
||||
# extract memory CLS tokens
|
||||
|
||||
memory_cls_tokens = out[:, 0]
|
||||
|
||||
# pass through task specific adapter head
|
||||
|
||||
return self.mlp_head(memory_cls_tokens)
|
||||
@@ -29,7 +29,7 @@ class FeedForward(nn.Module):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim * mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Hardswish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(dim * mult, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
@@ -52,6 +52,7 @@ class Attention(nn.Module):
|
||||
self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
out_batch_norm = nn.BatchNorm2d(dim_out)
|
||||
nn.init.zeros_(out_batch_norm.weight)
|
||||
@@ -70,8 +71,8 @@ class Attention(nn.Module):
|
||||
q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
|
||||
k_range = torch.arange(fmap_size)
|
||||
|
||||
q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
|
||||
k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)
|
||||
q_pos = torch.stack(torch.meshgrid(q_range, q_range, indexing = 'ij'), dim = -1)
|
||||
k_pos = torch.stack(torch.meshgrid(k_range, k_range, indexing = 'ij'), dim = -1)
|
||||
|
||||
q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
|
||||
rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
|
||||
@@ -100,6 +101,7 @@ class Attention(nn.Module):
|
||||
dots = self.apply_pos_bias(dots)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
|
||||
|
||||
@@ -78,6 +78,7 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
@@ -93,6 +94,7 @@ class Attention(nn.Module):
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
96
vit_pytorch/mae.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
|
||||
from vit_pytorch.vit import Transformer
|
||||
|
||||
class MAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
encoder,
|
||||
decoder_dim,
|
||||
masking_ratio = 0.75,
|
||||
decoder_depth = 1,
|
||||
decoder_heads = 8,
|
||||
decoder_dim_head = 64
|
||||
):
|
||||
super().__init__()
|
||||
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
|
||||
self.masking_ratio = masking_ratio
|
||||
|
||||
# extract some hyperparameters and functions from encoder (vision transformer to be trained)
|
||||
|
||||
self.encoder = encoder
|
||||
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
|
||||
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
|
||||
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
|
||||
|
||||
# decoder parameters
|
||||
|
||||
self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
|
||||
self.mask_token = nn.Parameter(torch.randn(decoder_dim))
|
||||
self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
|
||||
self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
|
||||
self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)
|
||||
|
||||
def forward(self, img):
|
||||
device = img.device
|
||||
|
||||
# get patches
|
||||
|
||||
patches = self.to_patch(img)
|
||||
batch, num_patches, *_ = patches.shape
|
||||
|
||||
# patch to encoder tokens and add positions
|
||||
|
||||
tokens = self.patch_to_emb(patches)
|
||||
tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]
|
||||
|
||||
# calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked
|
||||
|
||||
num_masked = int(self.masking_ratio * num_patches)
|
||||
rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
|
||||
masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]
|
||||
|
||||
# get the unmasked tokens to be encoded
|
||||
|
||||
batch_range = torch.arange(batch, device = device)[:, None]
|
||||
tokens = tokens[batch_range, unmasked_indices]
|
||||
|
||||
# get the patches to be masked for the final reconstruction loss
|
||||
|
||||
masked_patches = patches[batch_range, masked_indices]
|
||||
|
||||
# attend with vision transformer
|
||||
|
||||
encoded_tokens = self.encoder.transformer(tokens)
|
||||
|
||||
# project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder
|
||||
|
||||
decoder_tokens = self.enc_to_dec(encoded_tokens)
|
||||
|
||||
# reapply decoder position embedding to unmasked tokens
|
||||
|
||||
decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
|
||||
|
||||
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
|
||||
|
||||
mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
|
||||
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
|
||||
|
||||
# concat the masked tokens to the decoder tokens and attend with decoder
|
||||
|
||||
decoder_tokens = torch.cat((mask_tokens, decoder_tokens), dim = 1)
|
||||
decoded_tokens = self.decoder(decoder_tokens)
|
||||
|
||||
# splice out the mask tokens and project to pixel values
|
||||
|
||||
mask_tokens = decoded_tokens[:, :num_masked]
|
||||
pred_pixel_values = self.to_pixels(mask_tokens)
|
||||
|
||||
# calculate reconstruction loss
|
||||
|
||||
recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
|
||||
return recon_loss
|
||||
286
vit_pytorch/max_vit.py
Normal file
@@ -0,0 +1,286 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
# helper classes
|
||||
|
||||
class PreNormResidual(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x)) + x
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# MBConv
|
||||
|
||||
class SqueezeExcitation(nn.Module):
|
||||
def __init__(self, dim, shrinkage_rate = 0.25):
|
||||
super().__init__()
|
||||
hidden_dim = int(dim * shrinkage_rate)
|
||||
|
||||
self.gate = nn.Sequential(
|
||||
Reduce('b c h w -> b c', 'mean'),
|
||||
nn.Linear(dim, hidden_dim, bias = False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_dim, dim, bias = False),
|
||||
nn.Sigmoid(),
|
||||
Rearrange('b c -> b c 1 1')
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.gate(x)
|
||||
|
||||
|
||||
class MBConvResidual(nn.Module):
|
||||
def __init__(self, fn, dropout = 0.):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.dropsample = Dropsample(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fn(x)
|
||||
out = self.dropsample(out)
|
||||
return out + x
|
||||
|
||||
class Dropsample(nn.Module):
|
||||
def __init__(self, prob = 0):
|
||||
super().__init__()
|
||||
self.prob = prob
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
|
||||
if self.prob == 0. or (not self.training):
|
||||
return x
|
||||
|
||||
keep_mask = torch.FloatTensor((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
|
||||
return x * keep_mask / (1 - self.prob)
|
||||
|
||||
def MBConv(
|
||||
dim_in,
|
||||
dim_out,
|
||||
*,
|
||||
downsample,
|
||||
expansion_rate = 4,
|
||||
shrinkage_rate = 0.25,
|
||||
dropout = 0.
|
||||
):
|
||||
hidden_dim = int(expansion_rate * dim_out)
|
||||
stride = 2 if downsample else 1
|
||||
|
||||
net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, dim_out, 1),
|
||||
nn.BatchNorm2d(dim_out),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(dim_out, dim_out, 3, stride = stride, padding = 1, groups = dim_out),
|
||||
SqueezeExcitation(dim_out, shrinkage_rate = shrinkage_rate),
|
||||
nn.Conv2d(dim_out, dim_out, 1),
|
||||
nn.BatchNorm2d(dim_out)
|
||||
)
|
||||
|
||||
if dim_in == dim_out and not downsample:
|
||||
net = MBConvResidual(net, dropout = dropout)
|
||||
|
||||
return net
|
||||
|
||||
# attention related classes
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head = 32,
|
||||
dropout = 0.,
|
||||
window_size = 7
|
||||
):
|
||||
super().__init__()
|
||||
assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
|
||||
|
||||
self.heads = dim // dim_head
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
|
||||
|
||||
self.attend = nn.Sequential(
|
||||
nn.Softmax(dim = -1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(dim, dim, bias = False),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
# relative positional bias
|
||||
|
||||
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
|
||||
|
||||
pos = torch.arange(window_size)
|
||||
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
|
||||
grid = rearrange(grid, 'c i j -> (i j) c')
|
||||
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
|
||||
rel_pos += window_size - 1
|
||||
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
|
||||
|
||||
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
|
||||
|
||||
def forward(self, x):
|
||||
batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
|
||||
|
||||
# flatten
|
||||
|
||||
x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
|
||||
|
||||
# project for queries, keys, values
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
|
||||
|
||||
# split heads
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d ) -> b h n d', h = h), (q, k, v))
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# sim
|
||||
|
||||
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
|
||||
# add positional bias
|
||||
|
||||
bias = self.rel_pos_bias(self.rel_pos_indices)
|
||||
sim = sim + rearrange(bias, 'i j h -> h i j')
|
||||
|
||||
# attention
|
||||
|
||||
attn = self.attend(sim)
|
||||
|
||||
# aggregate
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)
|
||||
|
||||
# combine heads out
|
||||
|
||||
out = self.to_out(out)
|
||||
return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)
|
||||
|
||||
class MaxViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
dim_head = 32,
|
||||
dim_conv_stem = None,
|
||||
window_size = 7,
|
||||
mbconv_expansion_rate = 4,
|
||||
mbconv_shrinkage_rate = 0.25,
|
||||
dropout = 0.1,
|
||||
channels = 3
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
|
||||
|
||||
# convolutional stem
|
||||
|
||||
dim_conv_stem = default(dim_conv_stem, dim)
|
||||
|
||||
self.conv_stem = nn.Sequential(
|
||||
nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
|
||||
nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
|
||||
)
|
||||
|
||||
# variables
|
||||
|
||||
num_stages = len(depth)
|
||||
|
||||
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
|
||||
dims = (dim_conv_stem, *dims)
|
||||
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
# shorthand for window size for efficient block - grid like attention
|
||||
|
||||
w = window_size
|
||||
|
||||
# iterate through stages
|
||||
|
||||
for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
|
||||
for stage_ind in range(layer_depth):
|
||||
is_first = stage_ind == 0
|
||||
stage_dim_in = layer_dim_in if is_first else layer_dim
|
||||
|
||||
block = nn.Sequential(
|
||||
MBConv(
|
||||
stage_dim_in,
|
||||
layer_dim,
|
||||
downsample = is_first,
|
||||
expansion_rate = mbconv_expansion_rate,
|
||||
shrinkage_rate = mbconv_shrinkage_rate
|
||||
),
|
||||
Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
|
||||
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
|
||||
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
|
||||
Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
|
||||
|
||||
Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
|
||||
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
|
||||
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
|
||||
Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
|
||||
)
|
||||
|
||||
self.layers.append(block)
|
||||
|
||||
# mlp head out
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
Reduce('b d h w -> b d', 'mean'),
|
||||
nn.LayerNorm(dims[-1]),
|
||||
nn.Linear(dims[-1], num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_stem(x)
|
||||
|
||||
for stage in self.layers:
|
||||
x = stage(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||
252
vit_pytorch/mobile_vit.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def conv_1x1_bn(inp, oup):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim=-1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(
|
||||
t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b p h n d -> b p n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
"""Transformer block described in ViT.
|
||||
Paper: https://arxiv.org/abs/2010.11929
|
||||
Based on: https://github.com/lucidrains/vit-pytorch
|
||||
"""
|
||||
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class MV2Block(nn.Module):
|
||||
"""MV2 block described in MobileNetV2.
|
||||
Paper: https://arxiv.org/pdf/1801.04381
|
||||
Based on: https://github.com/tonylins/pytorch-mobilenet-v2
|
||||
"""
|
||||
|
||||
def __init__(self, inp, oup, stride=1, expansion=4):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(inp * expansion)
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
if expansion == 1:
|
||||
self.conv = nn.Sequential(
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
|
||||
1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Sequential(
|
||||
# pw
|
||||
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
|
||||
1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(x)
|
||||
if self.use_res_connect:
|
||||
out = out + x
|
||||
return out
|
||||
|
||||
class MobileViTBlock(nn.Module):
|
||||
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.ph, self.pw = patch_size
|
||||
|
||||
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
|
||||
self.conv2 = conv_1x1_bn(channel, dim)
|
||||
|
||||
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
|
||||
|
||||
self.conv3 = conv_1x1_bn(dim, channel)
|
||||
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
y = x.clone()
|
||||
|
||||
# Local representations
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
|
||||
# Global representations
|
||||
_, _, h, w = x.shape
|
||||
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
|
||||
ph=self.ph, pw=self.pw)
|
||||
x = self.transformer(x)
|
||||
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
|
||||
h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
|
||||
|
||||
# Fusion
|
||||
x = self.conv3(x)
|
||||
x = torch.cat((x, y), 1)
|
||||
x = self.conv4(x)
|
||||
return x
|
||||
|
||||
class MobileViT(nn.Module):
|
||||
"""MobileViT.
|
||||
Paper: https://arxiv.org/abs/2110.02178
|
||||
Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
dims,
|
||||
channels,
|
||||
num_classes,
|
||||
expansion=4,
|
||||
kernel_size=3,
|
||||
patch_size=(2, 2),
|
||||
depths=(2, 4, 3)
|
||||
):
|
||||
super().__init__()
|
||||
assert len(dims) == 3, 'dims must be a tuple of 3'
|
||||
assert len(depths) == 3, 'depths must be a tuple of 3'
|
||||
|
||||
ih, iw = image_size
|
||||
ph, pw = patch_size
|
||||
assert ih % ph == 0 and iw % pw == 0
|
||||
|
||||
init_dim, *_, last_dim = channels
|
||||
|
||||
self.conv1 = conv_nxn_bn(3, init_dim, stride=2)
|
||||
|
||||
self.stem = nn.ModuleList([])
|
||||
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
|
||||
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
|
||||
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
|
||||
self.trunk = nn.ModuleList([])
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[3], channels[4], 2, expansion),
|
||||
MobileViTBlock(dims[0], depths[0], channels[5],
|
||||
kernel_size, patch_size, int(dims[0] * 2))
|
||||
]))
|
||||
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[5], channels[6], 2, expansion),
|
||||
MobileViTBlock(dims[1], depths[1], channels[7],
|
||||
kernel_size, patch_size, int(dims[1] * 4))
|
||||
]))
|
||||
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[7], channels[8], 2, expansion),
|
||||
MobileViTBlock(dims[2], depths[2], channels[9],
|
||||
kernel_size, patch_size, int(dims[2] * 4))
|
||||
]))
|
||||
|
||||
self.to_logits = nn.Sequential(
|
||||
conv_1x1_bn(channels[-2], last_dim),
|
||||
Reduce('b c h w -> b c', 'mean'),
|
||||
nn.Linear(channels[-1], num_classes, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
|
||||
for conv in self.stem:
|
||||
x = conv(x)
|
||||
|
||||
for conv, attn in self.trunk:
|
||||
x = conv(x)
|
||||
x = attn(x)
|
||||
|
||||
return self.to_logits(x)
|
||||
@@ -20,9 +20,9 @@ class LayerNorm(nn.Module):
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (std + self.eps) * self.g + self.b
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
@@ -55,6 +55,7 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
@@ -71,6 +72,7 @@ class Attention(nn.Module):
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
|
||||
@@ -131,10 +133,11 @@ class NesT(nn.Module):
|
||||
|
||||
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy
|
||||
hierarchies = list(reversed(range(num_hierarchies)))
|
||||
mults = [2 ** i for i in hierarchies]
|
||||
mults = [2 ** i for i in reversed(hierarchies)]
|
||||
|
||||
layer_heads = list(map(lambda t: t * heads, mults))
|
||||
layer_dims = list(map(lambda t: t * dim, mults))
|
||||
last_dim = layer_dims[-1]
|
||||
|
||||
layer_dims = [*layer_dims, layer_dims[-1]]
|
||||
dim_pairs = zip(layer_dims[:-1], layer_dims[1:])
|
||||
@@ -157,10 +160,11 @@ class NesT(nn.Module):
|
||||
Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
LayerNorm(dim),
|
||||
LayerNorm(last_dim),
|
||||
Reduce('b c h w -> b c', 'mean'),
|
||||
nn.Linear(dim, num_classes)
|
||||
nn.Linear(last_dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
|
||||
140
vit_pytorch/parallel_vit.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class Parallel(nn.Module):
|
||||
def __init__(self, *fns):
|
||||
super().__init__()
|
||||
self.fns = nn.ModuleList(fns)
|
||||
|
||||
def forward(self, x):
|
||||
return sum([fn(x) for fn in self.fns])
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_parallel_branches = 2, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
attn_block = lambda: PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
|
||||
ff_block = lambda: PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Parallel(*[attn_block() for _ in range(num_parallel_branches)]),
|
||||
Parallel(*[ff_block() for _ in range(num_parallel_branches)]),
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attns, ffs in self.layers:
|
||||
x = attns(x) + x
|
||||
x = ffs(x) + x
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', num_parallel_branches = 2, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_parallel_branches, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
@@ -48,6 +48,7 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
@@ -63,6 +64,7 @@ class Attention(nn.Module):
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
@@ -129,14 +131,15 @@ class PiT(nn.Module):
|
||||
mlp_dim,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.
|
||||
emb_dropout = 0.,
|
||||
channels = 3
|
||||
):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing'
|
||||
heads = cast_tuple(heads, len(depth))
|
||||
|
||||
patch_dim = 3 * patch_size ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
nn.Unfold(kernel_size = patch_size, stride = patch_size // 2),
|
||||
@@ -175,7 +178,7 @@ class PiT(nn.Module):
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding
|
||||
x += self.pos_embedding[:, :n+1]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.layers(x)
|
||||
|
||||
@@ -55,5 +55,5 @@ class Recorder(nn.Module):
|
||||
target_device = self.device if self.device is not None else img.device
|
||||
recordings = tuple(map(lambda t: t.to(target_device), self.recordings))
|
||||
|
||||
attns = torch.stack(recordings, dim = 1)
|
||||
attns = torch.stack(recordings, dim = 1) if len(recordings) > 0 else None
|
||||
return pred, attns
|
||||
|
||||
269
vit_pytorch/regionvit.py
Normal file
@@ -0,0 +1,269 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
import torch.nn.functional as F
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
def divisible_by(val, d):
|
||||
return (val % d) == 0
|
||||
|
||||
# helper classes
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class PEG(nn.Module):
|
||||
def __init__(self, dim, kernel_size = 3):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(x) + x
|
||||
|
||||
# transformer classes
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0.):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, dim * mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim * mult, dim, 1)
|
||||
)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
heads = 4,
|
||||
dim_head = 32,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, rel_pos_bias = None):
|
||||
h = self.heads
|
||||
|
||||
# prenorm
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
|
||||
|
||||
# split heads
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
|
||||
q = q * self.scale
|
||||
|
||||
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
|
||||
# add relative positional bias for local tokens
|
||||
|
||||
if exists(rel_pos_bias):
|
||||
sim = sim + rel_pos_bias
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class R2LTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
window_size,
|
||||
depth = 4,
|
||||
heads = 4,
|
||||
dim_head = 32,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
self.window_size = window_size
|
||||
rel_positions = 2 * window_size - 1
|
||||
self.local_rel_pos_bias = nn.Embedding(rel_positions ** 2, heads)
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
|
||||
FeedForward(dim, dropout = ff_dropout)
|
||||
]))
|
||||
|
||||
def forward(self, local_tokens, region_tokens):
|
||||
device = local_tokens.device
|
||||
lh, lw = local_tokens.shape[-2:]
|
||||
rh, rw = region_tokens.shape[-2:]
|
||||
window_size_h, window_size_w = lh // rh, lw // rw
|
||||
|
||||
local_tokens = rearrange(local_tokens, 'b c h w -> b (h w) c')
|
||||
region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')
|
||||
|
||||
# calculate local relative positional bias
|
||||
|
||||
h_range = torch.arange(window_size_h, device = device)
|
||||
w_range = torch.arange(window_size_w, device = device)
|
||||
|
||||
grid_x, grid_y = torch.meshgrid(h_range, w_range, indexing = 'ij')
|
||||
grid = torch.stack((grid_x, grid_y))
|
||||
grid = rearrange(grid, 'c h w -> c (h w)')
|
||||
grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)
|
||||
bias_indices = (grid * torch.tensor([1, self.window_size * 2 - 1], device = device)[:, None, None]).sum(dim = 0)
|
||||
rel_pos_bias = self.local_rel_pos_bias(bias_indices)
|
||||
rel_pos_bias = rearrange(rel_pos_bias, 'i j h -> () h i j')
|
||||
rel_pos_bias = F.pad(rel_pos_bias, (1, 0, 1, 0), value = 0)
|
||||
|
||||
# go through r2l transformer layers
|
||||
|
||||
for attn, ff in self.layers:
|
||||
region_tokens = attn(region_tokens) + region_tokens
|
||||
|
||||
# concat region tokens to local tokens
|
||||
|
||||
local_tokens = rearrange(local_tokens, 'b (h w) d -> b h w d', h = lh)
|
||||
local_tokens = rearrange(local_tokens, 'b (h p1) (w p2) d -> (b h w) (p1 p2) d', p1 = window_size_h, p2 = window_size_w)
|
||||
region_tokens = rearrange(region_tokens, 'b n d -> (b n) () d')
|
||||
|
||||
# do self attention on local tokens, along with its regional token
|
||||
|
||||
region_and_local_tokens = torch.cat((region_tokens, local_tokens), dim = 1)
|
||||
region_and_local_tokens = attn(region_and_local_tokens, rel_pos_bias = rel_pos_bias) + region_and_local_tokens
|
||||
|
||||
# feedforward
|
||||
|
||||
region_and_local_tokens = ff(region_and_local_tokens) + region_and_local_tokens
|
||||
|
||||
# split back local and regional tokens
|
||||
|
||||
region_tokens, local_tokens = region_and_local_tokens[:, :1], region_and_local_tokens[:, 1:]
|
||||
local_tokens = rearrange(local_tokens, '(b h w) (p1 p2) d -> b (h p1 w p2) d', h = lh // window_size_h, w = lw // window_size_w, p1 = window_size_h)
|
||||
region_tokens = rearrange(region_tokens, '(b n) () d -> b n d', n = rh * rw)
|
||||
|
||||
local_tokens = rearrange(local_tokens, 'b (h w) c -> b c h w', h = lh, w = lw)
|
||||
region_tokens = rearrange(region_tokens, 'b (h w) c -> b c h w', h = rh, w = rw)
|
||||
return local_tokens, region_tokens
|
||||
|
||||
# classes
|
||||
|
||||
class RegionViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim = (64, 128, 256, 512),
|
||||
depth = (2, 2, 8, 2),
|
||||
window_size = 7,
|
||||
num_classes = 1000,
|
||||
tokenize_local_3_conv = False,
|
||||
local_patch_size = 4,
|
||||
use_peg = False,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.,
|
||||
channels = 3,
|
||||
):
|
||||
super().__init__()
|
||||
dim = cast_tuple(dim, 4)
|
||||
depth = cast_tuple(depth, 4)
|
||||
assert len(dim) == 4, 'dim needs to be a single value or a tuple of length 4'
|
||||
assert len(depth) == 4, 'depth needs to be a single value or a tuple of length 4'
|
||||
|
||||
self.local_patch_size = local_patch_size
|
||||
|
||||
region_patch_size = local_patch_size * window_size
|
||||
self.region_patch_size = local_patch_size * window_size
|
||||
|
||||
init_dim, *_, last_dim = dim
|
||||
|
||||
# local and region encoders
|
||||
|
||||
if tokenize_local_3_conv:
|
||||
self.local_encoder = nn.Sequential(
|
||||
nn.Conv2d(3, init_dim, 3, 2, 1),
|
||||
nn.LayerNorm(init_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(init_dim, init_dim, 3, 2, 1),
|
||||
nn.LayerNorm(init_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(init_dim, init_dim, 3, 1, 1)
|
||||
)
|
||||
else:
|
||||
self.local_encoder = nn.Conv2d(3, init_dim, 8, 4, 3)
|
||||
|
||||
self.region_encoder = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = region_patch_size, p2 = region_patch_size),
|
||||
nn.Conv2d((region_patch_size ** 2) * channels, init_dim, 1)
|
||||
)
|
||||
|
||||
# layers
|
||||
|
||||
current_dim = init_dim
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
for ind, dim, num_layers in zip(range(4), dim, depth):
|
||||
not_first = ind != 0
|
||||
need_downsample = not_first
|
||||
need_peg = not_first and use_peg
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
Downsample(current_dim, dim) if need_downsample else nn.Identity(),
|
||||
PEG(dim) if need_peg else nn.Identity(),
|
||||
R2LTransformer(dim, depth = num_layers, window_size = window_size, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
|
||||
]))
|
||||
|
||||
current_dim = dim
|
||||
|
||||
# final logits
|
||||
|
||||
self.to_logits = nn.Sequential(
|
||||
Reduce('b c h w -> b c', 'mean'),
|
||||
nn.LayerNorm(last_dim),
|
||||
nn.Linear(last_dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
*_, h, w = x.shape
|
||||
assert divisible_by(h, self.region_patch_size) and divisible_by(w, self.region_patch_size), 'height and width must be divisible by region patch size'
|
||||
assert divisible_by(h, self.local_patch_size) and divisible_by(w, self.local_patch_size), 'height and width must be divisible by local patch size'
|
||||
|
||||
local_tokens = self.local_encoder(x)
|
||||
region_tokens = self.region_encoder(x)
|
||||
|
||||
for down, peg, transformer in self.layers:
|
||||
local_tokens, region_tokens = down(local_tokens), down(region_tokens)
|
||||
local_tokens = peg(local_tokens)
|
||||
local_tokens, region_tokens = transformer(local_tokens, region_tokens)
|
||||
|
||||
return self.to_logits(region_tokens)
|
||||
@@ -19,7 +19,7 @@ class AxialRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_freq = 10):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
|
||||
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
|
||||
self.register_buffer('scales', scales)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -104,6 +104,7 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.use_ds_conv = use_ds_conv
|
||||
|
||||
@@ -148,16 +149,17 @@ class Attention(nn.Module):
|
||||
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, image_size, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
self.pos_emb = AxialRotaryEmbedding(dim_head)
|
||||
self.pos_emb = AxialRotaryEmbedding(dim_head, max_freq = image_size)
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
|
||||
@@ -187,7 +189,7 @@ class RvT(nn.Module):
|
||||
)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary, use_ds_conv, use_glu)
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, image_size, dropout, use_rotary, use_ds_conv, use_glu)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
306
vit_pytorch/scalable_vit.py
Normal file
@@ -0,0 +1,306 @@
|
||||
from functools import partial
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
# helper classes
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = ChanLayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x))
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class PEG(nn.Module):
|
||||
def __init__(self, dim, kernel_size = 3):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(x) + x
|
||||
|
||||
# feedforward
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, expansion_factor = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim * expansion_factor
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, inner_dim, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# attention
|
||||
|
||||
class ScalableSelfAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
heads = 8,
|
||||
dim_key = 32,
|
||||
dim_value = 32,
|
||||
dropout = 0.,
|
||||
reduction_factor = 1
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_key ** -0.5
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
|
||||
self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
|
||||
self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(dim_value * heads, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
height, width, heads = *x.shape[-2:], self.heads
|
||||
|
||||
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
|
||||
|
||||
# split out heads
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
|
||||
|
||||
# similarity
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
# attention
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
|
||||
# merge back heads
|
||||
|
||||
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = height, y = width)
|
||||
return self.to_out(out)
|
||||
|
||||
class InteractiveWindowedSelfAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
window_size,
|
||||
heads = 8,
|
||||
dim_key = 32,
|
||||
dim_value = 32,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_key ** -0.5
|
||||
self.window_size = window_size
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)
|
||||
|
||||
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
|
||||
self.to_k = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
|
||||
self.to_v = nn.Conv2d(dim, dim_value * heads, 1, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(dim_value * heads, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size
|
||||
|
||||
wsz_h, wsz_w = default(wsz, height), default(wsz, width)
|
||||
assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'
|
||||
|
||||
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
|
||||
|
||||
# get output of LIM
|
||||
|
||||
local_out = self.local_interactive_module(v)
|
||||
|
||||
# divide into window (and split out heads) for efficient self attention
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz_h, w2 = wsz_w), (q, k, v))
|
||||
|
||||
# similarity
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
# attention
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
|
||||
# reshape the windows back to full feature map (and merge heads)
|
||||
|
||||
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w)
|
||||
|
||||
# add LIM output
|
||||
|
||||
out = out + local_out
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
heads = 8,
|
||||
ff_expansion_factor = 4,
|
||||
dropout = 0.,
|
||||
ssa_dim_key = 32,
|
||||
ssa_dim_value = 32,
|
||||
ssa_reduction_factor = 1,
|
||||
iwsa_dim_key = 32,
|
||||
iwsa_dim_value = 32,
|
||||
iwsa_window_size = None,
|
||||
norm_output = True
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for ind in range(depth):
|
||||
is_first = ind == 0
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
|
||||
PEG(dim) if is_first else None,
|
||||
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
|
||||
PreNorm(dim, InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout))
|
||||
]))
|
||||
|
||||
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
for ssa, ff1, peg, iwsa, ff2 in self.layers:
|
||||
x = ssa(x) + x
|
||||
x = ff1(x) + x
|
||||
|
||||
if exists(peg):
|
||||
x = peg(x)
|
||||
|
||||
x = iwsa(x) + x
|
||||
x = ff2(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ScalableViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
reduction_factor,
|
||||
window_size = None,
|
||||
iwsa_dim_key = 32,
|
||||
iwsa_dim_value = 32,
|
||||
ssa_dim_key = 32,
|
||||
ssa_dim_value = 32,
|
||||
ff_expansion_factor = 4,
|
||||
channels = 3,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.to_patches = nn.Conv2d(channels, dim, 7, stride = 4, padding = 3)
|
||||
|
||||
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
|
||||
|
||||
num_stages = len(depth)
|
||||
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
|
||||
|
||||
hyperparams_per_stage = [
|
||||
heads,
|
||||
ssa_dim_key,
|
||||
ssa_dim_value,
|
||||
reduction_factor,
|
||||
iwsa_dim_key,
|
||||
iwsa_dim_value,
|
||||
window_size,
|
||||
]
|
||||
|
||||
hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
|
||||
assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
for ind, (layer_dim, layer_depth, layer_heads, layer_ssa_dim_key, layer_ssa_dim_value, layer_ssa_reduction_factor, layer_iwsa_dim_key, layer_iwsa_dim_value, layer_window_size) in enumerate(zip(dims, depth, *hyperparams_per_stage)):
|
||||
is_last = ind == (num_stages - 1)
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_expansion_factor = ff_expansion_factor, dropout = dropout, ssa_dim_key = layer_ssa_dim_key, ssa_dim_value = layer_ssa_dim_value, ssa_reduction_factor = layer_ssa_reduction_factor, iwsa_dim_key = layer_iwsa_dim_key, iwsa_dim_value = layer_iwsa_dim_value, iwsa_window_size = layer_window_size, norm_output = not is_last),
|
||||
Downsample(layer_dim, layer_dim * 2) if not is_last else None
|
||||
]))
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
Reduce('b d h w -> b d', 'mean'),
|
||||
nn.LayerNorm(dims[-1]),
|
||||
nn.Linear(dims[-1], num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patches(img)
|
||||
|
||||
for transformer, downsample in self.layers:
|
||||
x = transformer(x)
|
||||
|
||||
if exists(downsample):
|
||||
x = downsample(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||
294
vit_pytorch/sep_vit.py
Normal file
@@ -0,0 +1,294 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
# helper classes
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = ChanLayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x))
|
||||
|
||||
class OverlappingPatchEmbed(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, stride = 2):
|
||||
super().__init__()
|
||||
kernel_size = stride * 2 - 1
|
||||
padding = kernel_size // 2
|
||||
self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class PEG(nn.Module):
|
||||
def __init__(self, dim, kernel_size = 3):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(x) + x
|
||||
|
||||
# feedforward
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, inner_dim, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# attention
|
||||
|
||||
class DSSA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
heads = 8,
|
||||
dim_head = 32,
|
||||
dropout = 0.,
|
||||
window_size = 7
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.window_size = window_size
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.attend = nn.Sequential(
|
||||
nn.Softmax(dim = -1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
|
||||
|
||||
# window tokens
|
||||
|
||||
self.window_tokens = nn.Parameter(torch.randn(dim))
|
||||
|
||||
# prenorm and non-linearity for window tokens
|
||||
# then projection to queries and keys for window tokens
|
||||
|
||||
self.window_tokens_to_qk = nn.Sequential(
|
||||
nn.LayerNorm(dim_head),
|
||||
nn.GELU(),
|
||||
Rearrange('b h n c -> b (h c) n'),
|
||||
nn.Conv1d(inner_dim, inner_dim * 2, 1),
|
||||
Rearrange('b (h c) n -> b h n c', h = heads),
|
||||
)
|
||||
|
||||
# window attention
|
||||
|
||||
self.window_attend = nn.Sequential(
|
||||
nn.Softmax(dim = -1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
einstein notation
|
||||
|
||||
b - batch
|
||||
c - channels
|
||||
w1 - window size (height)
|
||||
w2 - also window size (width)
|
||||
i - sequence dimension (source)
|
||||
j - sequence dimension (target dimension to be reduced)
|
||||
h - heads
|
||||
x - height of feature map divided by window size
|
||||
y - width of feature map divided by window size
|
||||
"""
|
||||
|
||||
batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
|
||||
assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
|
||||
num_windows = (height // wsz) * (width // wsz)
|
||||
|
||||
# fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
|
||||
|
||||
x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
|
||||
|
||||
# add windowing tokens
|
||||
|
||||
w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0])
|
||||
x = torch.cat((w, x), dim = -1)
|
||||
|
||||
# project for queries, keys, value
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
|
||||
|
||||
# split out heads
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# similarity
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
|
||||
# attention
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
|
||||
# split out windowed tokens
|
||||
|
||||
window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]
|
||||
|
||||
# early return if there is only 1 window
|
||||
|
||||
if num_windows == 1:
|
||||
fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
|
||||
return self.to_out(fmap)
|
||||
|
||||
# carry out the pointwise attention, the main novelty in the paper
|
||||
|
||||
window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz)
|
||||
windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz)
|
||||
|
||||
# windowed queries and keys (preceded by prenorm activation)
|
||||
|
||||
w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)
|
||||
|
||||
# scale
|
||||
|
||||
w_q = w_q * self.scale
|
||||
|
||||
# similarities
|
||||
|
||||
w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)
|
||||
|
||||
w_attn = self.window_attend(w_dots)
|
||||
|
||||
# aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one i haven't seen before)
|
||||
|
||||
aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)
|
||||
|
||||
# fold back the windows and then combine heads for aggregation
|
||||
|
||||
fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
|
||||
return self.to_out(fmap)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
dim_head = 32,
|
||||
heads = 8,
|
||||
ff_mult = 4,
|
||||
dropout = 0.,
|
||||
norm_output = True
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
for ind in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = dropout)),
|
||||
]))
|
||||
|
||||
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class SepViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
window_size = 7,
|
||||
dim_head = 32,
|
||||
ff_mult = 4,
|
||||
channels = 3,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
|
||||
|
||||
num_stages = len(depth)
|
||||
|
||||
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
|
||||
dims = (channels, *dims)
|
||||
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
|
||||
|
||||
strides = (4, *((2,) * (num_stages - 1)))
|
||||
|
||||
hyperparams_per_stage = [heads, window_size]
|
||||
hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
|
||||
assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
for ind, ((layer_dim_in, layer_dim), layer_depth, layer_stride, layer_heads, layer_window_size) in enumerate(zip(dim_pairs, depth, strides, *hyperparams_per_stage)):
|
||||
is_last = ind == (num_stages - 1)
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
OverlappingPatchEmbed(layer_dim_in, layer_dim, stride = layer_stride),
|
||||
PEG(layer_dim),
|
||||
Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_mult = ff_mult, dropout = dropout, norm_output = not is_last),
|
||||
]))
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
Reduce('b d h w -> b d', 'mean'),
|
||||
nn.LayerNorm(dims[-1]),
|
||||
nn.Linear(dims[-1], num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
for ope, peg, transformer in self.layers:
|
||||
x = ope(x)
|
||||
x = peg(x)
|
||||
x = transformer(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||
84
vit_pytorch/simmim.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
|
||||
class SimMIM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
encoder,
|
||||
masking_ratio = 0.5
|
||||
):
|
||||
super().__init__()
|
||||
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
|
||||
self.masking_ratio = masking_ratio
|
||||
|
||||
# extract some hyperparameters and functions from encoder (vision transformer to be trained)
|
||||
|
||||
self.encoder = encoder
|
||||
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
|
||||
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
|
||||
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
|
||||
|
||||
# simple linear head
|
||||
|
||||
self.mask_token = nn.Parameter(torch.randn(encoder_dim))
|
||||
self.to_pixels = nn.Linear(encoder_dim, pixel_values_per_patch)
|
||||
|
||||
def forward(self, img):
|
||||
device = img.device
|
||||
|
||||
# get patches
|
||||
|
||||
patches = self.to_patch(img)
|
||||
batch, num_patches, *_ = patches.shape
|
||||
|
||||
# for indexing purposes
|
||||
|
||||
batch_range = torch.arange(batch, device = device)[:, None]
|
||||
|
||||
# get positions
|
||||
|
||||
pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)]
|
||||
|
||||
# patch to encoder tokens and add positions
|
||||
|
||||
tokens = self.patch_to_emb(patches)
|
||||
tokens = tokens + pos_emb
|
||||
|
||||
# prepare mask tokens
|
||||
|
||||
mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_patches)
|
||||
mask_tokens = mask_tokens + pos_emb
|
||||
|
||||
# calculate of patches needed to be masked, and get positions (indices) to be masked
|
||||
|
||||
num_masked = int(self.masking_ratio * num_patches)
|
||||
masked_indices = torch.rand(batch, num_patches, device = device).topk(k = num_masked, dim = -1).indices
|
||||
masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool()
|
||||
|
||||
# mask tokens
|
||||
|
||||
tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)
|
||||
|
||||
# attend with vision transformer
|
||||
|
||||
encoded = self.encoder.transformer(tokens)
|
||||
|
||||
# get the masked tokens
|
||||
|
||||
encoded_mask_tokens = encoded[batch_range, masked_indices]
|
||||
|
||||
# small linear projection for predicted pixel values
|
||||
|
||||
pred_pixel_values = self.to_pixels(encoded_mask_tokens)
|
||||
|
||||
# get the masked patches for the final reconstruction loss
|
||||
|
||||
masked_patches = patches[batch_range, masked_indices]
|
||||
|
||||
# calculate reconstruction loss
|
||||
|
||||
recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked
|
||||
return recon_loss
|
||||
@@ -72,7 +72,7 @@ class T2TViT(nn.Module):
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding
|
||||
x += self.pos_embedding[:, :n+1]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
@@ -38,9 +38,9 @@ class LayerNorm(nn.Module):
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (std + self.eps) * self.g + self.b
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
@@ -130,6 +130,8 @@ class GlobalAttention(nn.Module):
|
||||
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
|
||||
self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
@@ -145,6 +147,7 @@ class GlobalAttention(nn.Module):
|
||||
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
attn = dots.softmax(dim = -1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
@@ -43,6 +42,8 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
@@ -51,15 +52,15 @@ class Attention(nn.Module):
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
@@ -113,7 +114,7 @@ class ViT(nn.Module):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
145
vit_pytorch/vit_for_small_dataset.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from math import sqrt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class LSA(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.temperature.exp()
|
||||
|
||||
mask = torch.eye(dots.shape[-1], device = dots.device, dtype = torch.bool)
|
||||
mask_value = -torch.finfo(dots.dtype).max
|
||||
dots = dots.masked_fill(mask, mask_value)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class SPT(nn.Module):
|
||||
def __init__(self, *, dim, patch_size, channels = 3):
|
||||
super().__init__()
|
||||
patch_dim = patch_size * patch_size * 5 * channels
|
||||
|
||||
self.to_patch_tokens = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
shifts = ((1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1))
|
||||
shifted_x = list(map(lambda shift: F.pad(x, shift), shifts))
|
||||
x_with_shifts = torch.cat((x, *shifted_x), dim = 1)
|
||||
return self.to_patch_tokens(x_with_shifts)
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = SPT(dim = dim, patch_size = patch_size, channels = channels)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
147
vit_pytorch/vit_with_patch_merger.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val ,d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# patch merger class
|
||||
|
||||
class PatchMerger(nn.Module):
|
||||
def __init__(self, dim, num_tokens_out):
|
||||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.queries = nn.Parameter(torch.randn(num_tokens_out, dim))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
sim = torch.matmul(self.queries, x.transpose(-1, -2)) * self.scale
|
||||
attn = sim.softmax(dim = -1)
|
||||
return torch.matmul(attn, x)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper
|
||||
self.patch_merger = PatchMerger(dim = dim, num_tokens_out = patch_merge_num_tokens)
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for index, (attn, ff) in enumerate(self.layers):
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
if index == self.patch_merge_layer_index:
|
||||
x = self.patch_merger(x)
|
||||
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, patch_merge_layer, patch_merge_num_tokens)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
Reduce('b n d -> b d', 'mean'),
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
x += self.pos_embedding[:, :n]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||