mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
29fbf0aff4 | ||
|
|
4b8f5bc900 | ||
|
|
f86e052c05 | ||
|
|
2fa2b62def | ||
|
|
9f87d1c43b | ||
|
|
2c6dd7010a | ||
|
|
6460119f65 | ||
|
|
4e62e5f05e | ||
|
|
b3e90a2652 | ||
|
|
4ef72fc4dc | ||
|
|
c2aab05ebf | ||
|
|
81661e3966 | ||
|
|
13f8e123bb | ||
|
|
2d4089c88e | ||
|
|
c7bb5fc43f | ||
|
|
946b19be64 | ||
|
|
d93cd84ccd | ||
|
|
5d4c798949 | ||
|
|
d65a742efe | ||
|
|
8c54e01492 | ||
|
|
df656fe7c7 | ||
|
|
4e6a42a0ca | ||
|
|
6d7298d8ad |
3
.github/FUNDING.yml
vendored
Normal file
3
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: [lucidrains]
|
||||
346
README.md
346
README.md
@@ -6,6 +6,7 @@
|
||||
- [Install](#install)
|
||||
- [Usage](#usage)
|
||||
- [Parameters](#parameters)
|
||||
- [Simple ViT](#simple-vit)
|
||||
- [Distillation](#distillation)
|
||||
- [Deep ViT](#deep-vit)
|
||||
- [CaiT](#cait)
|
||||
@@ -19,6 +20,8 @@
|
||||
- [CrossFormer](#crossformer)
|
||||
- [RegionViT](#regionvit)
|
||||
- [ScalableViT](#scalablevit)
|
||||
- [SepViT](#sepvit)
|
||||
- [MaxViT](#maxvit)
|
||||
- [NesT](#nest)
|
||||
- [MobileViT](#mobilevit)
|
||||
- [Masked Autoencoder](#masked-autoencoder)
|
||||
@@ -27,8 +30,11 @@
|
||||
- [Adaptive Token Sampling](#adaptive-token-sampling)
|
||||
- [Patch Merger](#patch-merger)
|
||||
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
|
||||
- [3D Vit](#3d-vit)
|
||||
- [Parallel ViT](#parallel-vit)
|
||||
- [Learnable Memory ViT](#learnable-memory-vit)
|
||||
- [Dino](#dino)
|
||||
- [EsViT](#esvit)
|
||||
- [Accessing Attention](#accessing-attention)
|
||||
- [Research Ideas](#research-ideas)
|
||||
* [Efficient Attention](#efficient-attention)
|
||||
@@ -45,6 +51,10 @@ For a Pytorch implementation with pretrained models, please see Ross Wightman's
|
||||
|
||||
The official Jax repository is <a href="https://github.com/google-research/vision_transformer">here</a>.
|
||||
|
||||
A tensorflow2 translation also exists <a href="https://github.com/taki0112/vit-tensorflow">here</a>, created by research scientist <a href="https://github.com/taki0112">Junho Kim</a>! 🙏
|
||||
|
||||
<a href="https://github.com/conceptofmind/vit-flax">Flax translation</a> by <a href="https://github.com/conceptofmind">Enrico Shippole</a>!
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
@@ -100,6 +110,33 @@ Embedding dropout rate.
|
||||
- `pool`: string, either `cls` token pooling or `mean` pooling
|
||||
|
||||
|
||||
## Simple ViT
|
||||
|
||||
<a href="https://arxiv.org/abs/2205.01580">An update</a> from some of the same authors of the original paper proposes simplifications to `ViT` that allows it to train faster and better.
|
||||
|
||||
Among these simplifications include 2d sinusoidal positional embedding, global average pooling (no CLS token), no dropout, batch sizes of 1024 rather than 4096, and use of RandAugment and MixUp augmentations. They also show that a simple linear at the end is not significantly worse than the original MLP head
|
||||
|
||||
You can use it by importing the `SimpleViT` as shown below
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import SimpleViT
|
||||
|
||||
v = SimpleViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Distillation
|
||||
|
||||
<img src="./images/distill.png" width="300px"></img>
|
||||
@@ -556,13 +593,73 @@ model = ScalableViT(
|
||||
reduction_factor = (8, 4, 2, 1), # downsampling of the key / values in SSA. in the paper, this was represented as (reduction_factor ** -2)
|
||||
window_size = (64, 32, None, None), # window size of the IWSA at each stage. None means no windowing needed
|
||||
dropout = 0.1, # attention and feedforward dropout
|
||||
).cuda()
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256).cuda()
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
preds = model(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## SepViT
|
||||
|
||||
<img src="./images/sep-vit.png" width="400px"></img>
|
||||
|
||||
Another <a href="https://arxiv.org/abs/2203.15380">Bytedance AI paper</a>, it proposes a depthwise-pointwise self-attention layer that seems largely inspired by mobilenet's depthwise-separable convolution. The most interesting aspect is the reuse of the feature map from the depthwise self-attention stage as the values for the pointwise self-attention, as shown in the diagram above.
|
||||
|
||||
I have decided to include only the version of `SepViT` with this specific self-attention layer, as the grouped attention layers are not remarkable nor novel, and the authors were not clear on how they treated the window tokens for the group self-attention layer. Besides, it seems like with `DSSA` layer alone, they were able to beat Swin.
|
||||
|
||||
ex. SepViT-Lite
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.sep_vit import SepViT
|
||||
|
||||
v = SepViT(
|
||||
num_classes = 1000,
|
||||
dim = 32, # dimensions of first stage, which doubles every stage (32, 64, 128, 256) for SepViT-Lite
|
||||
dim_head = 32, # attention head dimension
|
||||
heads = (1, 2, 4, 8), # number of heads per stage
|
||||
depth = (1, 2, 6, 2), # number of transformer blocks per stage
|
||||
window_size = 7, # window size of DSS Attention block
|
||||
dropout = 0.1 # dropout
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## MaxViT
|
||||
|
||||
<img src="./images/max-vit.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2204.01697">This paper</a> proposes a hybrid convolutional / attention network, using MBConv from the convolution side, and then block / grid axial sparse attention.
|
||||
|
||||
They also claim this specific vision transformer is good for generative models (GANs).
|
||||
|
||||
ex. MaxViT-S
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.max_vit import MaxViT
|
||||
|
||||
v = MaxViT(
|
||||
num_classes = 1000,
|
||||
dim_conv_stem = 64, # dimension of the convolutional stem, would default to dimension of first layer if not specified
|
||||
dim = 96, # dimension of first layer, doubles every layer
|
||||
dim_head = 32, # dimension of attention heads, kept at 32 in paper
|
||||
depth = (2, 2, 5, 2), # number of MaxViT blocks per stage, which consists of MBConv, block-like attention, grid-like attention
|
||||
window_size = 7, # window size for block and grids
|
||||
mbconv_expansion_rate = 4, # expansion rate of MBConv
|
||||
mbconv_shrinkage_rate = 0.25, # shrinkage rate of squeeze-excitation in MBConv
|
||||
dropout = 0.1 # dropout
|
||||
)
|
||||
|
||||
img = torch.randn(2, 3, 224, 224)
|
||||
|
||||
preds = v(img) # (2, 1000)
|
||||
```
|
||||
|
||||
## NesT
|
||||
|
||||
<img src="./images/nest.png" width="400px"></img>
|
||||
@@ -871,6 +968,37 @@ img = torch.randn(4, 3, 256, 256)
|
||||
tokens = spt(img) # (4, 256, 1024)
|
||||
```
|
||||
|
||||
## 3D ViT
|
||||
|
||||
By popular request, I will start extending a few of the architectures in this repository to 3D ViTs, for use with video, medical imaging, etc.
|
||||
|
||||
You will need to pass in two additional hyperparameters: (1) the number of frames `frames` and (2) patch size along the frame dimension `frame_patch_size`
|
||||
|
||||
For starters, with the most basic ViT
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit_3d import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 128, # image size
|
||||
frames = 16, # number of frames
|
||||
image_patch_size = 16, # image patch size
|
||||
frame_patch_size = 2, # frame patch size
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
|
||||
|
||||
preds = v(video) # (4, 1000)
|
||||
```
|
||||
|
||||
## Parallel ViT
|
||||
|
||||
<img src="./images/parallel-vit.png" width="350px"></img>
|
||||
@@ -901,6 +1029,61 @@ img = torch.randn(4, 3, 256, 256)
|
||||
preds = v(img) # (4, 1000)
|
||||
```
|
||||
|
||||
## Learnable Memory ViT
|
||||
|
||||
<img src="./images/learnable-memory-vit.png" width="350px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2203.15243">paper</a> shows that adding learnable memory tokens at each layer of a vision transformer can greatly enhance fine-tuning results (in addition to learnable task specific CLS token and adapter head).
|
||||
|
||||
You can use this with a specially modified `ViT` as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.learnable_memory_vit import ViT, Adapter
|
||||
|
||||
# normal base ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
logits = v(img) # (4, 1000)
|
||||
|
||||
# do your usual training with ViT
|
||||
# ...
|
||||
|
||||
|
||||
# then, to finetune, just pass the ViT into the Adapter class
|
||||
# you can do this for multiple Adapters, as shown below
|
||||
|
||||
adapter1 = Adapter(
|
||||
vit = v,
|
||||
num_classes = 2, # number of output classes for this specific task
|
||||
num_memories_per_layer = 5 # number of learnable memories per layer, 10 was sufficient in paper
|
||||
)
|
||||
|
||||
logits1 = adapter1(img) # (4, 2) - predict 2 classes off frozen ViT backbone with learnable memories and task specific head
|
||||
|
||||
# yet another task to finetune on, this time with 4 classes
|
||||
|
||||
adapter2 = Adapter(
|
||||
vit = v,
|
||||
num_classes = 4,
|
||||
num_memories_per_layer = 10
|
||||
)
|
||||
|
||||
logits2 = adapter2(img) # (4, 4) - predict 4 classes off frozen ViT backbone with learnable memories and task specific head
|
||||
|
||||
```
|
||||
|
||||
## Dino
|
||||
|
||||
@@ -956,6 +1139,80 @@ for _ in range(100):
|
||||
torch.save(model.state_dict(), './pretrained-net.pt')
|
||||
```
|
||||
|
||||
## EsViT
|
||||
|
||||
<img src="./images/esvit.png" width="350px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2106.09785">`EsViT`</a> is a variant of Dino (from above) re-engineered to support efficient `ViT`s with patch merging / downsampling by taking into an account an extra regional loss between the augmented views. To quote the abstract, it `outperforms its supervised counterpart on 17 out of 18 datasets` at 3 times higher throughput.
|
||||
|
||||
Even though it is named as though it were a new `ViT` variant, it actually is just a strategy for training any multistage `ViT` (in the paper, they focused on Swin). The example below will show how to use it with `CvT`. You'll need to set the `hidden_layer` to the name of the layer within your efficient ViT that outputs the non-average pooled visual representations, just before the global pooling and projection to logits.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.cvt import CvT
|
||||
from vit_pytorch.es_vit import EsViTTrainer
|
||||
|
||||
cvt = CvT(
|
||||
num_classes = 1000,
|
||||
s1_emb_dim = 64,
|
||||
s1_emb_kernel = 7,
|
||||
s1_emb_stride = 4,
|
||||
s1_proj_kernel = 3,
|
||||
s1_kv_proj_stride = 2,
|
||||
s1_heads = 1,
|
||||
s1_depth = 1,
|
||||
s1_mlp_mult = 4,
|
||||
s2_emb_dim = 192,
|
||||
s2_emb_kernel = 3,
|
||||
s2_emb_stride = 2,
|
||||
s2_proj_kernel = 3,
|
||||
s2_kv_proj_stride = 2,
|
||||
s2_heads = 3,
|
||||
s2_depth = 2,
|
||||
s2_mlp_mult = 4,
|
||||
s3_emb_dim = 384,
|
||||
s3_emb_kernel = 3,
|
||||
s3_emb_stride = 2,
|
||||
s3_proj_kernel = 3,
|
||||
s3_kv_proj_stride = 2,
|
||||
s3_heads = 4,
|
||||
s3_depth = 10,
|
||||
s3_mlp_mult = 4,
|
||||
dropout = 0.
|
||||
)
|
||||
|
||||
learner = EsViTTrainer(
|
||||
cvt,
|
||||
image_size = 256,
|
||||
hidden_layer = 'layers', # hidden layer name or index, from which to extract the embedding
|
||||
projection_hidden_size = 256, # projector network hidden dimension
|
||||
projection_layers = 4, # number of layers in projection network
|
||||
num_classes_K = 65336, # output logits dimensions (referenced as K in paper)
|
||||
student_temp = 0.9, # student temperature
|
||||
teacher_temp = 0.04, # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
|
||||
local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper
|
||||
global_lower_crop_scale = 0.5, # lower bound for global crop - 0.5 was recommended in the paper
|
||||
moving_average_decay = 0.9, # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
|
||||
center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
|
||||
)
|
||||
|
||||
opt = torch.optim.AdamW(learner.parameters(), lr = 3e-4)
|
||||
|
||||
def sample_unlabelled_images():
|
||||
return torch.randn(8, 3, 256, 256)
|
||||
|
||||
for _ in range(1000):
|
||||
images = sample_unlabelled_images()
|
||||
loss = learner(images)
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
opt.step()
|
||||
learner.update_moving_average() # update moving average of teacher encoder and teacher centers
|
||||
|
||||
# save your improved network
|
||||
torch.save(cvt.state_dict(), './pretrained-net.pt')
|
||||
```
|
||||
|
||||
## Accessing Attention
|
||||
|
||||
If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below
|
||||
@@ -1032,6 +1289,47 @@ logits, embeddings = v(img)
|
||||
embeddings # (1, 65, 1024) - (batch x patches x model dim)
|
||||
```
|
||||
|
||||
Or say for `CrossViT`, which has a multi-scale encoder that outputs two sets of embeddings for 'large' and 'small' scales
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.cross_vit import CrossViT
|
||||
|
||||
v = CrossViT(
|
||||
image_size = 256,
|
||||
num_classes = 1000,
|
||||
depth = 4,
|
||||
sm_dim = 192,
|
||||
sm_patch_size = 16,
|
||||
sm_enc_depth = 2,
|
||||
sm_enc_heads = 8,
|
||||
sm_enc_mlp_dim = 2048,
|
||||
lg_dim = 384,
|
||||
lg_patch_size = 64,
|
||||
lg_enc_depth = 3,
|
||||
lg_enc_heads = 8,
|
||||
lg_enc_mlp_dim = 2048,
|
||||
cross_attn_depth = 2,
|
||||
cross_attn_heads = 8,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
# wrap the CrossViT
|
||||
|
||||
from vit_pytorch.extractor import Extractor
|
||||
v = Extractor(v, layer_name = 'multi_scale_encoder') # take embedding coming from the output of multi-scale-encoder
|
||||
|
||||
# forward pass now returns predictions and the attention maps
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
logits, embeddings = v(img)
|
||||
|
||||
# there is one extra token due to the CLS token
|
||||
|
||||
embeddings # ((1, 257, 192), (1, 17, 384)) - (batch x patches x dimension) <- large and small scales respectively
|
||||
```
|
||||
|
||||
## Research Ideas
|
||||
|
||||
### Efficient Attention
|
||||
@@ -1440,6 +1738,50 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Sandler2022FinetuningIT,
|
||||
title = {Fine-tuning Image Transformers using Learnable Memory},
|
||||
author = {Mark Sandler and Andrey Zhmoginov and Max Vladymyrov and Andrew Jackson},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Li2022SepViTSV,
|
||||
title = {SepViT: Separable Vision Transformer},
|
||||
author = {Wei Li and Xing Wang and Xin Xia and Jie Wu and Xuefeng Xiao and Minghang Zheng and Shiping Wen},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Tu2022MaxViTMV,
|
||||
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||
author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Li2021EfficientSV,
|
||||
title = {Efficient Self-supervised Vision Transformers for Representation Learning},
|
||||
author = {Chunyuan Li and Jianwei Yang and Pengchuan Zhang and Mei Gao and Bin Xiao and Xiyang Dai and Lu Yuan and Jianfeng Gao},
|
||||
journal = {ArXiv},
|
||||
year = {2021},
|
||||
volume = {abs/2106.09785}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{Beyer2022BetterPlainViT
|
||||
title = {Better plain ViT baselines for ImageNet-1k},
|
||||
author = {Beyer, Lucas and Zhai, Xiaohua and Kolesnikov, Alexander},
|
||||
publisher = {arXiv},
|
||||
year = {2022}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
|
||||
BIN
images/esvit.png
Normal file
BIN
images/esvit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 190 KiB |
BIN
images/learnable-memory-vit.png
Normal file
BIN
images/learnable-memory-vit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 108 KiB |
BIN
images/max-vit.png
Normal file
BIN
images/max-vit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 133 KiB |
BIN
images/sep-vit.png
Normal file
BIN
images/sep-vit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 142 KiB |
5
setup.py
5
setup.py
@@ -3,9 +3,10 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.29.1',
|
||||
version = '0.36.0',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description_content_type = 'text/markdown',
|
||||
author = 'Phil Wang',
|
||||
author_email = 'lucidrains@gmail.com',
|
||||
url = 'https://github.com/lucidrains/vit-pytorch',
|
||||
@@ -16,7 +17,7 @@ setup(
|
||||
],
|
||||
install_requires=[
|
||||
'einops>=0.4.1',
|
||||
'torch>=1.6',
|
||||
'torch>=1.10',
|
||||
'torchvision'
|
||||
],
|
||||
setup_requires=[
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.simple_vit import SimpleViT
|
||||
|
||||
from vit_pytorch.mae import MAE
|
||||
from vit_pytorch.dino import Dino
|
||||
|
||||
@@ -139,6 +139,8 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.output_num_tokens = output_num_tokens
|
||||
@@ -163,6 +165,7 @@ class Attention(nn.Module):
|
||||
dots = dots.masked_fill(~dots_mask, mask_value)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
sampled_token_ids = None
|
||||
|
||||
|
||||
@@ -76,6 +76,7 @@ class Attention(nn.Module):
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
|
||||
self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))
|
||||
@@ -96,7 +97,10 @@ class Attention(nn.Module):
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
|
||||
@@ -48,6 +48,8 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
@@ -69,6 +71,7 @@ class Attention(nn.Module):
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
@@ -95,6 +95,9 @@ class Attention(nn.Module):
|
||||
self.window_size = window_size
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(inner_dim, dim, 1)
|
||||
|
||||
@@ -105,7 +108,7 @@ class Attention(nn.Module):
|
||||
# calculate and store indices for retrieving bias
|
||||
|
||||
pos = torch.arange(window_size)
|
||||
grid = torch.stack(torch.meshgrid(pos, pos))
|
||||
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
|
||||
grid = rearrange(grid, 'c i j -> (i j) c')
|
||||
rel_pos = grid[:, None] - grid[None, :]
|
||||
rel_pos += window_size - 1
|
||||
@@ -141,7 +144,7 @@ class Attention(nn.Module):
|
||||
# add dynamic positional bias
|
||||
|
||||
pos = torch.arange(-wsz, wsz + 1, device = device)
|
||||
rel_pos = torch.stack(torch.meshgrid(pos, pos))
|
||||
rel_pos = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
|
||||
rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
|
||||
biases = self.dpb(rel_pos.float())
|
||||
rel_pos_bias = biases[self.rel_pos_indices]
|
||||
@@ -151,6 +154,7 @@ class Attention(nn.Module):
|
||||
# attend
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# merge heads
|
||||
|
||||
|
||||
@@ -76,6 +76,7 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
|
||||
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)
|
||||
@@ -94,6 +95,7 @@ class Attention(nn.Module):
|
||||
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
|
||||
@@ -162,12 +164,14 @@ class CvT(nn.Module):
|
||||
|
||||
dim = config['emb_dim']
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
*layers,
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
self.to_logits = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
Rearrange('... () () -> ...'),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
latents = self.layers(x)
|
||||
return self.to_logits(latents)
|
||||
|
||||
@@ -42,6 +42,8 @@ class Attention(nn.Module):
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
|
||||
|
||||
self.reattn_norm = nn.Sequential(
|
||||
@@ -64,6 +66,7 @@ class Attention(nn.Module):
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
attn = dots.softmax(dim=-1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# re-attention
|
||||
|
||||
|
||||
367
vit_pytorch/es_vit.py
Normal file
367
vit_pytorch/es_vit.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import copy
|
||||
import random
|
||||
from functools import wraps, partial
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms as T
|
||||
|
||||
from einops import rearrange, reduce, repeat
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, default):
|
||||
return val if exists(val) else default
|
||||
|
||||
def singleton(cache_key):
|
||||
def inner_fn(fn):
|
||||
@wraps(fn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
instance = getattr(self, cache_key)
|
||||
if instance is not None:
|
||||
return instance
|
||||
|
||||
instance = fn(self, *args, **kwargs)
|
||||
setattr(self, cache_key, instance)
|
||||
return instance
|
||||
return wrapper
|
||||
return inner_fn
|
||||
|
||||
def get_module_device(module):
|
||||
return next(module.parameters()).device
|
||||
|
||||
def set_requires_grad(model, val):
|
||||
for p in model.parameters():
|
||||
p.requires_grad = val
|
||||
|
||||
# tensor related helpers
|
||||
|
||||
def log(t, eps = 1e-20):
|
||||
return torch.log(t + eps)
|
||||
|
||||
# loss function # (algorithm 1 in the paper)
|
||||
|
||||
def view_loss_fn(
|
||||
teacher_logits,
|
||||
student_logits,
|
||||
teacher_temp,
|
||||
student_temp,
|
||||
centers,
|
||||
eps = 1e-20
|
||||
):
|
||||
teacher_logits = teacher_logits.detach()
|
||||
student_probs = (student_logits / student_temp).softmax(dim = -1)
|
||||
teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
|
||||
return - (teacher_probs * log(student_probs, eps)).sum(dim = -1).mean()
|
||||
|
||||
def region_loss_fn(
|
||||
teacher_logits,
|
||||
student_logits,
|
||||
teacher_latent,
|
||||
student_latent,
|
||||
teacher_temp,
|
||||
student_temp,
|
||||
centers,
|
||||
eps = 1e-20
|
||||
):
|
||||
teacher_logits = teacher_logits.detach()
|
||||
student_probs = (student_logits / student_temp).softmax(dim = -1)
|
||||
teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
|
||||
|
||||
sim_matrix = einsum('b i d, b j d -> b i j', student_latent, teacher_latent)
|
||||
sim_indices = sim_matrix.max(dim = -1).indices
|
||||
sim_indices = repeat(sim_indices, 'b n -> b n k', k = teacher_probs.shape[-1])
|
||||
max_sim_teacher_probs = teacher_probs.gather(1, sim_indices)
|
||||
|
||||
return - (max_sim_teacher_probs * log(student_probs, eps)).sum(dim = -1).mean()
|
||||
|
||||
# augmentation utils
|
||||
|
||||
class RandomApply(nn.Module):
|
||||
def __init__(self, fn, p):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.p = p
|
||||
|
||||
def forward(self, x):
|
||||
if random.random() > self.p:
|
||||
return x
|
||||
return self.fn(x)
|
||||
|
||||
# exponential moving average
|
||||
|
||||
class EMA():
|
||||
def __init__(self, beta):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
|
||||
def update_average(self, old, new):
|
||||
if old is None:
|
||||
return new
|
||||
return old * self.beta + (1 - self.beta) * new
|
||||
|
||||
def update_moving_average(ema_updater, ma_model, current_model):
|
||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||
old_weight, up_weight = ma_params.data, current_params.data
|
||||
ma_params.data = ema_updater.update_average(old_weight, up_weight)
|
||||
|
||||
# MLP class for projector and predictor
|
||||
|
||||
class L2Norm(nn.Module):
|
||||
def forward(self, x, eps = 1e-6):
|
||||
return F.normalize(x, dim = 1, eps = eps)
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
|
||||
super().__init__()
|
||||
|
||||
layers = []
|
||||
dims = (dim, *((hidden_size,) * (num_layers - 1)))
|
||||
|
||||
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
is_last = ind == (len(dims) - 1)
|
||||
|
||||
layers.extend([
|
||||
nn.Linear(layer_dim_in, layer_dim_out),
|
||||
nn.GELU() if not is_last else nn.Identity()
|
||||
])
|
||||
|
||||
self.net = nn.Sequential(
|
||||
*layers,
|
||||
L2Norm(),
|
||||
nn.Linear(hidden_size, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# a wrapper class for the base neural network
|
||||
# will manage the interception of the hidden layer output
|
||||
# and pipe it into the projecter and predictor nets
|
||||
|
||||
class NetWrapper(nn.Module):
|
||||
def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
self.layer = layer
|
||||
|
||||
self.view_projector = None
|
||||
self.region_projector = None
|
||||
self.projection_hidden_size = projection_hidden_size
|
||||
self.projection_num_layers = projection_num_layers
|
||||
self.output_dim = output_dim
|
||||
|
||||
self.hidden = {}
|
||||
self.hook_registered = False
|
||||
|
||||
def _find_layer(self):
|
||||
if type(self.layer) == str:
|
||||
modules = dict([*self.net.named_modules()])
|
||||
return modules.get(self.layer, None)
|
||||
elif type(self.layer) == int:
|
||||
children = [*self.net.children()]
|
||||
return children[self.layer]
|
||||
return None
|
||||
|
||||
def _hook(self, _, input, output):
|
||||
device = input[0].device
|
||||
self.hidden[device] = output
|
||||
|
||||
def _register_hook(self):
|
||||
layer = self._find_layer()
|
||||
assert layer is not None, f'hidden layer ({self.layer}) not found'
|
||||
handle = layer.register_forward_hook(self._hook)
|
||||
self.hook_registered = True
|
||||
|
||||
@singleton('view_projector')
|
||||
def _get_view_projector(self, hidden):
|
||||
dim = hidden.shape[1]
|
||||
projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
|
||||
return projector.to(hidden)
|
||||
|
||||
@singleton('region_projector')
|
||||
def _get_region_projector(self, hidden):
|
||||
dim = hidden.shape[1]
|
||||
projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
|
||||
return projector.to(hidden)
|
||||
|
||||
def get_embedding(self, x):
|
||||
if self.layer == -1:
|
||||
return self.net(x)
|
||||
|
||||
if not self.hook_registered:
|
||||
self._register_hook()
|
||||
|
||||
self.hidden.clear()
|
||||
_ = self.net(x)
|
||||
hidden = self.hidden[x.device]
|
||||
self.hidden.clear()
|
||||
|
||||
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
|
||||
return hidden
|
||||
|
||||
def forward(self, x, return_projection = True):
|
||||
region_latents = self.get_embedding(x)
|
||||
global_latent = reduce(region_latents, 'b c h w -> b c', 'mean')
|
||||
|
||||
if not return_projection:
|
||||
return global_latent, region_latents
|
||||
|
||||
view_projector = self._get_view_projector(global_latent)
|
||||
region_projector = self._get_region_projector(region_latents)
|
||||
|
||||
region_latents = rearrange(region_latents, 'b c h w -> b (h w) c')
|
||||
|
||||
return view_projector(global_latent), region_projector(region_latents), region_latents
|
||||
|
||||
# main class
|
||||
|
||||
class EsViTTrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
net,
|
||||
image_size,
|
||||
hidden_layer = -2,
|
||||
projection_hidden_size = 256,
|
||||
num_classes_K = 65336,
|
||||
projection_layers = 4,
|
||||
student_temp = 0.9,
|
||||
teacher_temp = 0.04,
|
||||
local_upper_crop_scale = 0.4,
|
||||
global_lower_crop_scale = 0.5,
|
||||
moving_average_decay = 0.9,
|
||||
center_moving_average_decay = 0.9,
|
||||
augment_fn = None,
|
||||
augment_fn2 = None
|
||||
):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
|
||||
# default BYOL augmentation
|
||||
|
||||
DEFAULT_AUG = torch.nn.Sequential(
|
||||
RandomApply(
|
||||
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
|
||||
p = 0.3
|
||||
),
|
||||
T.RandomGrayscale(p=0.2),
|
||||
T.RandomHorizontalFlip(),
|
||||
RandomApply(
|
||||
T.GaussianBlur((3, 3), (1.0, 2.0)),
|
||||
p = 0.2
|
||||
),
|
||||
T.Normalize(
|
||||
mean=torch.tensor([0.485, 0.456, 0.406]),
|
||||
std=torch.tensor([0.229, 0.224, 0.225])),
|
||||
)
|
||||
|
||||
self.augment1 = default(augment_fn, DEFAULT_AUG)
|
||||
self.augment2 = default(augment_fn2, DEFAULT_AUG)
|
||||
|
||||
# local and global crops
|
||||
|
||||
self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
|
||||
self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))
|
||||
|
||||
self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)
|
||||
|
||||
self.teacher_encoder = None
|
||||
self.teacher_ema_updater = EMA(moving_average_decay)
|
||||
|
||||
self.register_buffer('teacher_view_centers', torch.zeros(1, num_classes_K))
|
||||
self.register_buffer('last_teacher_view_centers', torch.zeros(1, num_classes_K))
|
||||
|
||||
self.register_buffer('teacher_region_centers', torch.zeros(1, num_classes_K))
|
||||
self.register_buffer('last_teacher_region_centers', torch.zeros(1, num_classes_K))
|
||||
|
||||
self.teacher_centering_ema_updater = EMA(center_moving_average_decay)
|
||||
|
||||
self.student_temp = student_temp
|
||||
self.teacher_temp = teacher_temp
|
||||
|
||||
# get device of network and make wrapper same device
|
||||
device = get_module_device(net)
|
||||
self.to(device)
|
||||
|
||||
# send a mock image tensor to instantiate singleton parameters
|
||||
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
|
||||
|
||||
@singleton('teacher_encoder')
|
||||
def _get_teacher_encoder(self):
|
||||
teacher_encoder = copy.deepcopy(self.student_encoder)
|
||||
set_requires_grad(teacher_encoder, False)
|
||||
return teacher_encoder
|
||||
|
||||
def reset_moving_average(self):
|
||||
del self.teacher_encoder
|
||||
self.teacher_encoder = None
|
||||
|
||||
def update_moving_average(self):
|
||||
assert self.teacher_encoder is not None, 'target encoder has not been created yet'
|
||||
update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)
|
||||
|
||||
new_teacher_view_centers = self.teacher_centering_ema_updater.update_average(self.teacher_view_centers, self.last_teacher_view_centers)
|
||||
self.teacher_view_centers.copy_(new_teacher_view_centers)
|
||||
|
||||
new_teacher_region_centers = self.teacher_centering_ema_updater.update_average(self.teacher_region_centers, self.last_teacher_region_centers)
|
||||
self.teacher_region_centers.copy_(new_teacher_region_centers)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_embedding = False,
|
||||
return_projection = True,
|
||||
student_temp = None,
|
||||
teacher_temp = None
|
||||
):
|
||||
if return_embedding:
|
||||
return self.student_encoder(x, return_projection = return_projection)
|
||||
|
||||
image_one, image_two = self.augment1(x), self.augment2(x)
|
||||
|
||||
local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
|
||||
global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
|
||||
|
||||
student_view_proj_one, student_region_proj_one, student_latent_one = self.student_encoder(local_image_one)
|
||||
student_view_proj_two, student_region_proj_two, student_latent_two = self.student_encoder(local_image_two)
|
||||
|
||||
with torch.no_grad():
|
||||
teacher_encoder = self._get_teacher_encoder()
|
||||
teacher_view_proj_one, teacher_region_proj_one, teacher_latent_one = teacher_encoder(global_image_one)
|
||||
teacher_view_proj_two, teacher_region_proj_two, teacher_latent_two = teacher_encoder(global_image_two)
|
||||
|
||||
view_loss_fn_ = partial(
|
||||
view_loss_fn,
|
||||
student_temp = default(student_temp, self.student_temp),
|
||||
teacher_temp = default(teacher_temp, self.teacher_temp),
|
||||
centers = self.teacher_view_centers
|
||||
)
|
||||
|
||||
region_loss_fn_ = partial(
|
||||
region_loss_fn,
|
||||
student_temp = default(student_temp, self.student_temp),
|
||||
teacher_temp = default(teacher_temp, self.teacher_temp),
|
||||
centers = self.teacher_region_centers
|
||||
)
|
||||
|
||||
# calculate view-level loss
|
||||
|
||||
teacher_view_logits_avg = torch.cat((teacher_view_proj_one, teacher_view_proj_two)).mean(dim = 0)
|
||||
self.last_teacher_view_centers.copy_(teacher_view_logits_avg)
|
||||
|
||||
teacher_region_logits_avg = torch.cat((teacher_region_proj_one, teacher_region_proj_two)).mean(dim = (0, 1))
|
||||
self.last_teacher_region_centers.copy_(teacher_region_logits_avg)
|
||||
|
||||
view_loss = (view_loss_fn_(teacher_view_proj_one, student_view_proj_two) \
|
||||
+ view_loss_fn_(teacher_view_proj_two, student_view_proj_one)) / 2
|
||||
|
||||
# calculate region-level loss
|
||||
|
||||
region_loss = (region_loss_fn_(teacher_region_proj_one, student_region_proj_two, teacher_latent_one, student_latent_two) \
|
||||
+ region_loss_fn_(teacher_region_proj_two, student_region_proj_one, teacher_latent_two, student_latent_one)) / 2
|
||||
|
||||
return (view_loss + region_loss) / 2
|
||||
@@ -4,14 +4,27 @@ from torch import nn
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def identity(t):
|
||||
return t
|
||||
|
||||
def clone_and_detach(t):
|
||||
return t.clone().detach()
|
||||
|
||||
def apply_tuple_or_single(fn, val):
|
||||
if isinstance(val, tuple):
|
||||
return tuple(map(fn, val))
|
||||
return fn(val)
|
||||
|
||||
class Extractor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vit,
|
||||
device = None,
|
||||
layer = None,
|
||||
layer_name = 'transformer',
|
||||
layer_save_input = False,
|
||||
return_embeddings_only = False
|
||||
return_embeddings_only = False,
|
||||
detach = True
|
||||
):
|
||||
super().__init__()
|
||||
self.vit = vit
|
||||
@@ -23,17 +36,24 @@ class Extractor(nn.Module):
|
||||
self.ejected = False
|
||||
self.device = device
|
||||
|
||||
self.layer = layer
|
||||
self.layer_name = layer_name
|
||||
self.layer_save_input = layer_save_input # whether to save input or output of layer
|
||||
self.return_embeddings_only = return_embeddings_only
|
||||
|
||||
self.detach_fn = clone_and_detach if detach else identity
|
||||
|
||||
def _hook(self, _, inputs, output):
|
||||
tensor_to_save = inputs if self.layer_save_input else output
|
||||
self.latents = tensor_to_save.clone().detach()
|
||||
layer_output = inputs if self.layer_save_input else output
|
||||
self.latents = apply_tuple_or_single(self.detach_fn, layer_output)
|
||||
|
||||
def _register_hook(self):
|
||||
assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
|
||||
layer = getattr(self.vit, self.layer_name)
|
||||
if not exists(self.layer):
|
||||
assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
|
||||
layer = getattr(self.vit, self.layer_name)
|
||||
else:
|
||||
layer = self.layer
|
||||
|
||||
handle = layer.register_forward_hook(self._hook)
|
||||
self.hooks.append(handle)
|
||||
self.hook_registered = True
|
||||
@@ -62,7 +82,7 @@ class Extractor(nn.Module):
|
||||
pred = self.vit(img)
|
||||
|
||||
target_device = self.device if exists(self.device) else img.device
|
||||
latents = self.latents.to(target_device)
|
||||
latents = apply_tuple_or_single(lambda t: t.to(target_device), self.latents)
|
||||
|
||||
if return_embeddings_only or self.return_embeddings_only:
|
||||
return latents
|
||||
|
||||
216
vit_pytorch/learnable_memory_vit.py
Normal file
216
vit_pytorch/learnable_memory_vit.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# controlling freezing of layers
|
||||
|
||||
def set_module_requires_grad_(module, requires_grad):
|
||||
for param in module.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
|
||||
def freeze_all_layers_(module):
|
||||
set_module_requires_grad_(module, False)
|
||||
|
||||
def unfreeze_all_layers_(module):
|
||||
set_module_requires_grad_(module, True)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, attn_mask = None, memories = None):
|
||||
x = self.norm(x)
|
||||
|
||||
x_kv = x # input for key / values projection
|
||||
|
||||
if exists(memories):
|
||||
# add memories to key / values if it is passed in
|
||||
memories = repeat(memories, 'n d -> b n d', b = x.shape[0]) if memories.ndim == 2 else memories
|
||||
x_kv = torch.cat((x_kv, memories), dim = 1)
|
||||
|
||||
qkv = (self.to_q(x), *self.to_kv(x_kv).chunk(2, dim = -1))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
if exists(attn_mask):
|
||||
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x, attn_mask = None, memories = None):
|
||||
for ind, (attn, ff) in enumerate(self.layers):
|
||||
layer_memories = memories[ind] if exists(memories) else None
|
||||
|
||||
x = attn(x, attn_mask = attn_mask, memories = layer_memories) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def img_to_tokens(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = x.shape[0])
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
|
||||
x += self.pos_embedding
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
def forward(self, img):
|
||||
x = self.img_to_tokens(img)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
cls_tokens = x[:, 0]
|
||||
return self.mlp_head(cls_tokens)
|
||||
|
||||
# adapter with learnable memories per layer, memory CLS token, and learnable adapter head
|
||||
|
||||
class Adapter(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vit,
|
||||
num_memories_per_layer = 10,
|
||||
num_classes = 2,
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(vit, ViT)
|
||||
|
||||
# extract some model variables needed
|
||||
|
||||
dim = vit.cls_token.shape[-1]
|
||||
layers = len(vit.transformer.layers)
|
||||
num_patches = vit.pos_embedding.shape[-2]
|
||||
|
||||
self.vit = vit
|
||||
|
||||
# freeze ViT backbone - only memories will be finetuned
|
||||
|
||||
freeze_all_layers_(vit)
|
||||
|
||||
# learnable parameters
|
||||
|
||||
self.memory_cls_token = nn.Parameter(torch.randn(dim))
|
||||
self.memories_per_layer = nn.Parameter(torch.randn(layers, num_memories_per_layer, dim))
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
# specialized attention mask to preserve the output of the original ViT
|
||||
# it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa
|
||||
|
||||
attn_mask = torch.ones((num_patches, num_patches), dtype = torch.bool)
|
||||
attn_mask = F.pad(attn_mask, (1, num_memories_per_layer), value = False) # main tokens cannot attend to learnable memories per layer
|
||||
attn_mask = F.pad(attn_mask, (0, 0, 1, 0), value = True) # memory CLS token can attend to everything
|
||||
self.register_buffer('attn_mask', attn_mask)
|
||||
|
||||
def forward(self, img):
|
||||
b = img.shape[0]
|
||||
|
||||
tokens = self.vit.img_to_tokens(img)
|
||||
|
||||
# add task specific memory tokens
|
||||
|
||||
memory_cls_tokens = repeat(self.memory_cls_token, 'd -> b 1 d', b = b)
|
||||
tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)
|
||||
|
||||
# pass memories along with image tokens through transformer for attending
|
||||
|
||||
out = self.vit.transformer(tokens, memories = self.memories_per_layer, attn_mask = self.attn_mask)
|
||||
|
||||
# extract memory CLS tokens
|
||||
|
||||
memory_cls_tokens = out[:, 0]
|
||||
|
||||
# pass through task specific adapter head
|
||||
|
||||
return self.mlp_head(memory_cls_tokens)
|
||||
@@ -52,6 +52,7 @@ class Attention(nn.Module):
|
||||
self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
out_batch_norm = nn.BatchNorm2d(dim_out)
|
||||
nn.init.zeros_(out_batch_norm.weight)
|
||||
@@ -70,8 +71,8 @@ class Attention(nn.Module):
|
||||
q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
|
||||
k_range = torch.arange(fmap_size)
|
||||
|
||||
q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
|
||||
k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)
|
||||
q_pos = torch.stack(torch.meshgrid(q_range, q_range, indexing = 'ij'), dim = -1)
|
||||
k_pos = torch.stack(torch.meshgrid(k_range, k_range, indexing = 'ij'), dim = -1)
|
||||
|
||||
q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
|
||||
rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
|
||||
@@ -100,6 +101,7 @@ class Attention(nn.Module):
|
||||
dots = self.apply_pos_bias(dots)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
|
||||
|
||||
@@ -78,6 +78,7 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
@@ -93,6 +94,7 @@ class Attention(nn.Module):
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
288
vit_pytorch/max_vit.py
Normal file
288
vit_pytorch/max_vit.py
Normal file
@@ -0,0 +1,288 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
# helper classes
|
||||
|
||||
class PreNormResidual(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x)) + x
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# MBConv
|
||||
|
||||
class SqueezeExcitation(nn.Module):
|
||||
def __init__(self, dim, shrinkage_rate = 0.25):
|
||||
super().__init__()
|
||||
hidden_dim = int(dim * shrinkage_rate)
|
||||
|
||||
self.gate = nn.Sequential(
|
||||
Reduce('b c h w -> b c', 'mean'),
|
||||
nn.Linear(dim, hidden_dim, bias = False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_dim, dim, bias = False),
|
||||
nn.Sigmoid(),
|
||||
Rearrange('b c -> b c 1 1')
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.gate(x)
|
||||
|
||||
|
||||
class MBConvResidual(nn.Module):
|
||||
def __init__(self, fn, dropout = 0.):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.dropsample = Dropsample(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fn(x)
|
||||
out = self.dropsample(out)
|
||||
return out + x
|
||||
|
||||
class Dropsample(nn.Module):
|
||||
def __init__(self, prob = 0):
|
||||
super().__init__()
|
||||
self.prob = prob
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
|
||||
if self.prob == 0. or (not self.training):
|
||||
return x
|
||||
|
||||
keep_mask = torch.FloatTensor((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
|
||||
return x * keep_mask / (1 - self.prob)
|
||||
|
||||
def MBConv(
|
||||
dim_in,
|
||||
dim_out,
|
||||
*,
|
||||
downsample,
|
||||
expansion_rate = 4,
|
||||
shrinkage_rate = 0.25,
|
||||
dropout = 0.
|
||||
):
|
||||
hidden_dim = int(expansion_rate * dim_out)
|
||||
stride = 2 if downsample else 1
|
||||
|
||||
net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, hidden_dim, 1),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.GELU(),
|
||||
SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
|
||||
nn.Conv2d(hidden_dim, dim_out, 1),
|
||||
nn.BatchNorm2d(dim_out)
|
||||
)
|
||||
|
||||
if dim_in == dim_out and not downsample:
|
||||
net = MBConvResidual(net, dropout = dropout)
|
||||
|
||||
return net
|
||||
|
||||
# attention related classes
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head = 32,
|
||||
dropout = 0.,
|
||||
window_size = 7
|
||||
):
|
||||
super().__init__()
|
||||
assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
|
||||
|
||||
self.heads = dim // dim_head
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
|
||||
|
||||
self.attend = nn.Sequential(
|
||||
nn.Softmax(dim = -1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(dim, dim, bias = False),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
# relative positional bias
|
||||
|
||||
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
|
||||
|
||||
pos = torch.arange(window_size)
|
||||
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
|
||||
grid = rearrange(grid, 'c i j -> (i j) c')
|
||||
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
|
||||
rel_pos += window_size - 1
|
||||
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
|
||||
|
||||
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
|
||||
|
||||
def forward(self, x):
|
||||
batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
|
||||
|
||||
# flatten
|
||||
|
||||
x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
|
||||
|
||||
# project for queries, keys, values
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
|
||||
|
||||
# split heads
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d ) -> b h n d', h = h), (q, k, v))
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# sim
|
||||
|
||||
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
|
||||
# add positional bias
|
||||
|
||||
bias = self.rel_pos_bias(self.rel_pos_indices)
|
||||
sim = sim + rearrange(bias, 'i j h -> h i j')
|
||||
|
||||
# attention
|
||||
|
||||
attn = self.attend(sim)
|
||||
|
||||
# aggregate
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)
|
||||
|
||||
# combine heads out
|
||||
|
||||
out = self.to_out(out)
|
||||
return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)
|
||||
|
||||
class MaxViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
dim_head = 32,
|
||||
dim_conv_stem = None,
|
||||
window_size = 7,
|
||||
mbconv_expansion_rate = 4,
|
||||
mbconv_shrinkage_rate = 0.25,
|
||||
dropout = 0.1,
|
||||
channels = 3
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
|
||||
|
||||
# convolutional stem
|
||||
|
||||
dim_conv_stem = default(dim_conv_stem, dim)
|
||||
|
||||
self.conv_stem = nn.Sequential(
|
||||
nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
|
||||
nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
|
||||
)
|
||||
|
||||
# variables
|
||||
|
||||
num_stages = len(depth)
|
||||
|
||||
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
|
||||
dims = (dim_conv_stem, *dims)
|
||||
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
# shorthand for window size for efficient block - grid like attention
|
||||
|
||||
w = window_size
|
||||
|
||||
# iterate through stages
|
||||
|
||||
for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
|
||||
for stage_ind in range(layer_depth):
|
||||
is_first = stage_ind == 0
|
||||
stage_dim_in = layer_dim_in if is_first else layer_dim
|
||||
|
||||
block = nn.Sequential(
|
||||
MBConv(
|
||||
stage_dim_in,
|
||||
layer_dim,
|
||||
downsample = is_first,
|
||||
expansion_rate = mbconv_expansion_rate,
|
||||
shrinkage_rate = mbconv_shrinkage_rate
|
||||
),
|
||||
Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
|
||||
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
|
||||
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
|
||||
Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
|
||||
|
||||
Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
|
||||
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
|
||||
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
|
||||
Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
|
||||
)
|
||||
|
||||
self.layers.append(block)
|
||||
|
||||
# mlp head out
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
Reduce('b d h w -> b d', 'mean'),
|
||||
nn.LayerNorm(dims[-1]),
|
||||
nn.Linear(dims[-1], num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_stem(x)
|
||||
|
||||
for stage in self.layers:
|
||||
x = stage(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||
@@ -54,6 +54,8 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim=-1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
@@ -67,7 +69,10 @@ class Attention(nn.Module):
|
||||
t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b p h n d -> b p n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
@@ -55,6 +55,7 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
@@ -71,6 +72,7 @@ class Attention(nn.Module):
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
|
||||
|
||||
@@ -50,6 +50,8 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
@@ -64,6 +66,7 @@ class Attention(nn.Module):
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
@@ -48,6 +48,7 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
@@ -63,6 +64,7 @@ class Attention(nn.Module):
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
@@ -61,8 +61,13 @@ class Attention(nn.Module):
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, rel_pos_bias = None):
|
||||
h = self.heads
|
||||
@@ -86,6 +91,7 @@ class Attention(nn.Module):
|
||||
sim = sim + rel_pos_bias
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# merge heads
|
||||
|
||||
@@ -132,7 +138,7 @@ class R2LTransformer(nn.Module):
|
||||
h_range = torch.arange(window_size_h, device = device)
|
||||
w_range = torch.arange(window_size_w, device = device)
|
||||
|
||||
grid_x, grid_y = torch.meshgrid(h_range, w_range)
|
||||
grid_x, grid_y = torch.meshgrid(h_range, w_range, indexing = 'ij')
|
||||
grid = torch.stack((grid_x, grid_y))
|
||||
grid = rearrange(grid, 'c h w -> c (h w)')
|
||||
grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)
|
||||
|
||||
@@ -104,6 +104,7 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.use_ds_conv = use_ds_conv
|
||||
|
||||
@@ -148,6 +149,7 @@ class Attention(nn.Module):
|
||||
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
|
||||
|
||||
@@ -90,6 +90,7 @@ class ScalableSelfAttention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim_key ** -0.5
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
|
||||
self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
|
||||
@@ -116,6 +117,7 @@ class ScalableSelfAttention(nn.Module):
|
||||
# attention
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# aggregate values
|
||||
|
||||
@@ -141,6 +143,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
|
||||
self.scale = dim_key ** -0.5
|
||||
self.window_size = window_size
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)
|
||||
|
||||
@@ -176,6 +179,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
|
||||
# attention
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# aggregate values
|
||||
|
||||
@@ -280,7 +284,7 @@ class ScalableViT(nn.Module):
|
||||
is_last = ind == (num_stages - 1)
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_expansion_factor = ff_expansion_factor, dropout = dropout, ssa_dim_key = layer_ssa_dim_key, ssa_dim_value = layer_ssa_dim_value, ssa_reduction_factor = layer_ssa_reduction_factor, iwsa_dim_key = layer_iwsa_dim_key, iwsa_dim_value = layer_iwsa_dim_value, iwsa_window_size = layer_window_size),
|
||||
Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_expansion_factor = ff_expansion_factor, dropout = dropout, ssa_dim_key = layer_ssa_dim_key, ssa_dim_value = layer_ssa_dim_value, ssa_reduction_factor = layer_ssa_reduction_factor, iwsa_dim_key = layer_iwsa_dim_key, iwsa_dim_value = layer_iwsa_dim_value, iwsa_window_size = layer_window_size, norm_output = not is_last),
|
||||
Downsample(layer_dim, layer_dim * 2) if not is_last else None
|
||||
]))
|
||||
|
||||
|
||||
294
vit_pytorch/sep_vit.py
Normal file
294
vit_pytorch/sep_vit.py
Normal file
@@ -0,0 +1,294 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
# helper classes
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = ChanLayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x))
|
||||
|
||||
class OverlappingPatchEmbed(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, stride = 2):
|
||||
super().__init__()
|
||||
kernel_size = stride * 2 - 1
|
||||
padding = kernel_size // 2
|
||||
self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class PEG(nn.Module):
|
||||
def __init__(self, dim, kernel_size = 3):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(x) + x
|
||||
|
||||
# feedforward
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, inner_dim, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# attention
|
||||
|
||||
class DSSA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
heads = 8,
|
||||
dim_head = 32,
|
||||
dropout = 0.,
|
||||
window_size = 7
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.window_size = window_size
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.attend = nn.Sequential(
|
||||
nn.Softmax(dim = -1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
|
||||
|
||||
# window tokens
|
||||
|
||||
self.window_tokens = nn.Parameter(torch.randn(dim))
|
||||
|
||||
# prenorm and non-linearity for window tokens
|
||||
# then projection to queries and keys for window tokens
|
||||
|
||||
self.window_tokens_to_qk = nn.Sequential(
|
||||
nn.LayerNorm(dim_head),
|
||||
nn.GELU(),
|
||||
Rearrange('b h n c -> b (h c) n'),
|
||||
nn.Conv1d(inner_dim, inner_dim * 2, 1),
|
||||
Rearrange('b (h c) n -> b h n c', h = heads),
|
||||
)
|
||||
|
||||
# window attention
|
||||
|
||||
self.window_attend = nn.Sequential(
|
||||
nn.Softmax(dim = -1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
einstein notation
|
||||
|
||||
b - batch
|
||||
c - channels
|
||||
w1 - window size (height)
|
||||
w2 - also window size (width)
|
||||
i - sequence dimension (source)
|
||||
j - sequence dimension (target dimension to be reduced)
|
||||
h - heads
|
||||
x - height of feature map divided by window size
|
||||
y - width of feature map divided by window size
|
||||
"""
|
||||
|
||||
batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
|
||||
assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
|
||||
num_windows = (height // wsz) * (width // wsz)
|
||||
|
||||
# fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
|
||||
|
||||
x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
|
||||
|
||||
# add windowing tokens
|
||||
|
||||
w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0])
|
||||
x = torch.cat((w, x), dim = -1)
|
||||
|
||||
# project for queries, keys, value
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
|
||||
|
||||
# split out heads
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# similarity
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
|
||||
# attention
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
|
||||
# split out windowed tokens
|
||||
|
||||
window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]
|
||||
|
||||
# early return if there is only 1 window
|
||||
|
||||
if num_windows == 1:
|
||||
fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
|
||||
return self.to_out(fmap)
|
||||
|
||||
# carry out the pointwise attention, the main novelty in the paper
|
||||
|
||||
window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz)
|
||||
windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz)
|
||||
|
||||
# windowed queries and keys (preceded by prenorm activation)
|
||||
|
||||
w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)
|
||||
|
||||
# scale
|
||||
|
||||
w_q = w_q * self.scale
|
||||
|
||||
# similarities
|
||||
|
||||
w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)
|
||||
|
||||
w_attn = self.window_attend(w_dots)
|
||||
|
||||
# aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one i haven't seen before)
|
||||
|
||||
aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)
|
||||
|
||||
# fold back the windows and then combine heads for aggregation
|
||||
|
||||
fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
|
||||
return self.to_out(fmap)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
dim_head = 32,
|
||||
heads = 8,
|
||||
ff_mult = 4,
|
||||
dropout = 0.,
|
||||
norm_output = True
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
for ind in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = dropout)),
|
||||
]))
|
||||
|
||||
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class SepViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
window_size = 7,
|
||||
dim_head = 32,
|
||||
ff_mult = 4,
|
||||
channels = 3,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
|
||||
|
||||
num_stages = len(depth)
|
||||
|
||||
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
|
||||
dims = (channels, *dims)
|
||||
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
|
||||
|
||||
strides = (4, *((2,) * (num_stages - 1)))
|
||||
|
||||
hyperparams_per_stage = [heads, window_size]
|
||||
hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
|
||||
assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
for ind, ((layer_dim_in, layer_dim), layer_depth, layer_stride, layer_heads, layer_window_size) in enumerate(zip(dim_pairs, depth, strides, *hyperparams_per_stage)):
|
||||
is_last = ind == (num_stages - 1)
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
OverlappingPatchEmbed(layer_dim_in, layer_dim, stride = layer_stride),
|
||||
PEG(layer_dim),
|
||||
Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_mult = ff_mult, dropout = dropout, norm_output = not is_last),
|
||||
]))
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
Reduce('b d h w -> b d', 'mean'),
|
||||
nn.LayerNorm(dims[-1]),
|
||||
nn.Linear(dims[-1], num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
for ope, peg, transformer in self.layers:
|
||||
x = ope(x)
|
||||
x = peg(x)
|
||||
x = transformer(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||
116
vit_pytorch/simple_vit.py
Normal file
116
vit_pytorch/simple_vit.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
|
||||
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
||||
|
||||
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
|
||||
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
|
||||
omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.linear_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
*_, h, w, dtype = *img.shape, img.dtype
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
pe = posemb_sincos_2d(x)
|
||||
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
@@ -130,6 +130,8 @@ class GlobalAttention(nn.Module):
|
||||
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
|
||||
self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
@@ -145,6 +147,7 @@ class GlobalAttention(nn.Module):
|
||||
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
attn = dots.softmax(dim = -1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
|
||||
|
||||
@@ -42,6 +42,8 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
@@ -56,6 +58,7 @@ class Attention(nn.Module):
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
@@ -111,7 +114,7 @@ class ViT(nn.Module):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
129
vit_pytorch/vit_3d.py
Normal file
129
vit_pytorch/vit_3d.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(image_patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
@@ -42,6 +42,8 @@ class LSA(nn.Module):
|
||||
self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
@@ -60,6 +62,7 @@ class LSA(nn.Module):
|
||||
dots = dots.masked_fill(mask, mask_value)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
@@ -63,6 +63,8 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
@@ -77,6 +79,7 @@ class Attention(nn.Module):
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
Reference in New Issue
Block a user