Compare commits

..

18 Commits

Author SHA1 Message Date
Phil Wang
d93cd84ccd let windowed tokens exchange information across heads a la talking heads prior to pointwise attention in sep-vit 2022-03-31 15:22:24 -07:00
Phil Wang
5d4c798949 cleanup sepvit 2022-03-31 14:35:11 -07:00
Phil Wang
d65a742efe intent to build (#210)
complete SepViT, from bytedance AI labs
2022-03-31 14:30:23 -07:00
Phil Wang
8c54e01492 do not layernorm on last transformer block for scalable vit, as there is already one in mlp head 2022-03-31 13:25:21 -07:00
Phil Wang
df656fe7c7 complete learnable memory ViT, for efficient fine-tuning and potentially plays into continual learning 2022-03-31 09:51:12 -07:00
Phil Wang
4e6a42a0ca correct need for post-attention dropout 2022-03-30 10:50:57 -07:00
Phil Wang
6d7298d8ad link to tensorflow2 translation by @taki0112 2022-03-28 09:05:34 -07:00
Phil Wang
9cd56ff29b CCT allow for rectangular images 2022-03-26 14:02:49 -07:00
Phil Wang
2aae406ce8 add proposed parallel vit from facebook ai for exploration purposes 2022-03-23 10:42:35 -07:00
Phil Wang
c2b2db2a54 fix window size of none for scalable vit for rectangular images 2022-03-22 17:37:59 -07:00
Phil Wang
719048d1bd some better defaults for scalable vit 2022-03-22 17:19:58 -07:00
Phil Wang
d27721a85a add scalable vit, from bytedance AI 2022-03-22 17:02:47 -07:00
Phil Wang
cb22cbbd19 update to einops 0.4, which is torchscript jit friendly 2022-03-22 13:58:00 -07:00
Phil Wang
6db20debb4 add patch merger 2022-03-01 16:50:17 -08:00
Phil Wang
1bae5d3cc5 allow for rectangular images for efficient adapter 2022-01-31 08:55:31 -08:00
Phil Wang
25b384297d return None from extractor if no attention layers 2022-01-28 17:49:58 -08:00
Phil Wang
64a07f50e6 epsilon should be inside square root 2022-01-24 17:24:41 -08:00
Phil Wang
126d204ff2 fix block repeats in readme example for Nest 2022-01-22 21:32:53 -08:00
32 changed files with 1464 additions and 52 deletions

314
README.md
View File

@@ -18,13 +18,18 @@
- [Twins SVT](#twins-svt)
- [CrossFormer](#crossformer)
- [RegionViT](#regionvit)
- [ScalableViT](#scalablevit)
- [SepViT](#sepvit)
- [NesT](#nest)
- [MobileViT](#mobilevit)
- [Masked Autoencoder](#masked-autoencoder)
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
- [Adaptive Token Sampling](#adaptive-token-sampling)
- [Patch Merger](#patch-merger)
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
- [Parallel ViT](#parallel-vit)
- [Learnable Memory ViT](#learnable-memory-vit)
- [Dino](#dino)
- [Accessing Attention](#accessing-attention)
- [Research Ideas](#research-ideas)
@@ -42,6 +47,8 @@ For a Pytorch implementation with pretrained models, please see Ross Wightman's
The official Jax repository is <a href="https://github.com/google-research/vision_transformer">here</a>.
A tensorflow2 translation also exists <a href="https://github.com/taki0112/vit-tensorflow">here</a>, created by research scientist <a href="https://github.com/taki0112">Junho Kim</a>! 🙏
## Install
```bash
@@ -238,6 +245,7 @@ preds = v(img) # (1, 1000)
```
## CCT
<img src="https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png" width="400px"></img>
<a href="https://arxiv.org/abs/2104.05704">CCT</a> proposes compact transformers
@@ -249,22 +257,25 @@ You can use this with two methods
import torch
from vit_pytorch.cct import CCT
model = CCT(
img_size=224,
embedding_dim=384,
n_conv_layers=2,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
num_layers=14,
num_heads=6,
mlp_radio=3.,
num_classes=1000,
positional_embedding='learnable', # ['sine', 'learnable', 'none']
)
cct = CCT(
img_size = (224, 448),
embedding_dim = 384,
n_conv_layers = 2,
kernel_size = 7,
stride = 2,
padding = 3,
pooling_kernel_size = 3,
pooling_stride = 2,
pooling_padding = 1,
num_layers = 14,
num_heads = 6,
mlp_radio = 3.,
num_classes = 1000,
positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
)
img = torch.randn(1, 3, 224, 448)
pred = cct(img) # (1, 1000)
```
Alternatively you can use one of several pre-defined models `[2,4,6,7,8,14,16]`
@@ -275,23 +286,23 @@ and the embedding dimension.
import torch
from vit_pytorch.cct import cct_14
model = cct_14(
img_size=224,
n_conv_layers=1,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
num_classes=1000,
positional_embedding='learnable', # ['sine', 'learnable', 'none']
)
cct = cct_14(
img_size = 224,
n_conv_layers = 1,
kernel_size = 7,
stride = 2,
padding = 3,
pooling_kernel_size = 3,
pooling_stride = 2,
pooling_padding = 1,
num_classes = 1000,
positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
)
```
<a href="https://github.com/SHI-Labs/Compact-Transformers">Official
Repository</a> includes links to pretrained model checkpoints.
## Cross ViT
<img src="./images/cross_vit.png" width="400px"></img>
@@ -524,6 +535,67 @@ img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```
## ScalableViT
<img src="./images/scalable-vit-1.png" width="400px"></img>
<img src="./images/scalable-vit-2.png" width="400px"></img>
This Bytedance AI <a href="https://arxiv.org/abs/2203.10790">paper</a> proposes the Scalable Self Attention (SSA) and the Interactive Windowed Self Attention (IWSA) modules. The SSA alleviates the computation needed at earlier stages by reducing the key / value feature map by some factor (`reduction_factor`), while modulating the dimension of the queries and keys (`ssa_dim_key`). The IWSA performs self attention within local windows, similar to other vision transformer papers. However, they add a residual of the values, passed through a convolution of kernel size 3, which they named Local Interactive Module (LIM).
They make the claim in this paper that this scheme outperforms Swin Transformer, and also demonstrate competitive performance against Crossformer.
You can use it as follows (ex. ScalableViT-S)
```python
import torch
from vit_pytorch.scalable_vit import ScalableViT
model = ScalableViT(
num_classes = 1000,
dim = 64, # starting model dimension. at every stage, dimension is doubled
heads = (2, 4, 8, 16), # number of attention heads at each stage
depth = (2, 2, 20, 2), # number of transformer blocks at each stage
ssa_dim_key = (40, 40, 40, 32), # the dimension of the attention keys (and queries) for SSA. in the paper, they represented this as a scale factor on the base dimension per key (ssa_dim_key / dim_key)
reduction_factor = (8, 4, 2, 1), # downsampling of the key / values in SSA. in the paper, this was represented as (reduction_factor ** -2)
window_size = (64, 32, None, None), # window size of the IWSA at each stage. None means no windowing needed
dropout = 0.1, # attention and feedforward dropout
)
img = torch.randn(1, 3, 256, 256)
preds = model(img) # (1, 1000)
```
## SepViT
<img src="./images/sep-vit.png" width="400px"></img>
Another <a href="https://arxiv.org/abs/2203.15380">Bytedance AI paper</a>, it proposes a depthwise-pointwise self-attention layer that seems largely inspired by mobilenet's depthwise-separable convolution. The most interesting aspect is the reuse of the feature map from the depthwise self-attention stage as the values for the pointwise self-attention, as shown in the diagram above.
I have decided to include only the version of `SepViT` with this specific self-attention layer, as the grouped attention layers are not remarkable nor novel, and the authors were not clear on how they treated the window tokens for the group self-attention layer. Besides, it seems like with `DSSA` layer alone, they were able to beat Swin.
ex. SepViT-Lite
```python
import torch
from vit_pytorch.sep_vit import SepViT
v = SepViT(
num_classes = 1000,
dim = 32, # dimensions of first stage, which doubles every stage (32, 64, 128, 256) for SepViT-Lite
dim_head = 32, # attention head dimension
heads = (1, 2, 4, 8), # number of heads per stage
depth = (1, 2, 6, 2), # number of transformer blocks per stage
window_size = 7, # window size of DSS Attention block
dropout = 0.1 # dropout
)
img = torch.randn(1, 3, 224, 224)
preds = v(img) # (1, 1000)
```
## NesT
<img src="./images/nest.png" width="400px"></img>
@@ -542,7 +614,7 @@ nest = NesT(
dim = 96,
heads = 3,
num_hierarchies = 3, # number of hierarchies
block_repeats = (8, 4, 1), # the number of transformer blocks at each heirarchy, starting from the bottom
block_repeats = (2, 2, 8), # the number of transformer blocks at each heirarchy, starting from the bottom
num_classes = 1000
)
@@ -732,12 +804,58 @@ v = ViT(
img = torch.randn(4, 3, 256, 256)
preds = v(img) # (1, 1000)
preds = v(img) # (4, 1000)
# you can also get a list of the final sampled patch ids
# a value of -1 denotes padding
preds, token_ids = v(img, return_sampled_token_ids = True) # (1, 1000), (1, <=8)
preds, token_ids = v(img, return_sampled_token_ids = True) # (4, 1000), (4, <=8)
```
## Patch Merger
<img src="./images/patch_merger.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2202.12015">paper</a> proposes a simple module (Patch Merger) for reducing the number of tokens at any layer of a vision transformer without sacrificing performance.
```python
import torch
from vit_pytorch.vit_with_patch_merger import ViT
v = ViT(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 1024,
depth = 12,
heads = 8,
patch_merge_layer = 6, # at which transformer layer to do patch merging
patch_merge_num_tokens = 8, # the output number of tokens from the patch merge
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(4, 3, 256, 256)
preds = v(img) # (4, 1000)
```
One can also use the `PatchMerger` module by itself
```python
import torch
from vit_pytorch.vit_with_patch_merger import PatchMerger
merger = PatchMerger(
dim = 1024,
num_tokens_out = 8 # output number of tokens
)
features = torch.randn(4, 256, 1024) # (batch, num tokens, dimension)
out = merger(features) # (4, 8, 1024)
```
## Vision Transformer for Small Datasets
@@ -786,6 +904,92 @@ img = torch.randn(4, 3, 256, 256)
tokens = spt(img) # (4, 256, 1024)
```
## Parallel ViT
<img src="./images/parallel-vit.png" width="350px"></img>
This <a href="https://arxiv.org/abs/2203.09795">paper</a> propose parallelizing multiple attention and feedforward blocks per layer (2 blocks), claiming that it is easier to train without loss of performance.
You can try this variant as follows
```python
import torch
from vit_pytorch.parallel_vit import ViT
v = ViT(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
num_parallel_branches = 2, # in paper, they claimed 2 was optimal
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(4, 3, 256, 256)
preds = v(img) # (4, 1000)
```
## Learnable Memory ViT
<img src="./images/learnable-memory-vit.png" width="350px"></img>
This <a href="https://arxiv.org/abs/2203.15243">paper</a> shows that adding learnable memory tokens at each layer of a vision transformer can greatly enhance fine-tuning results (in addition to learnable task specific CLS token and adapter head).
You can use this with a specially modified `ViT` as follows
```python
import torch
from vit_pytorch.learnable_memory_vit import ViT, Adapter
# normal base ViT
v = ViT(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(4, 3, 256, 256)
logits = v(img) # (4, 1000)
# do your usual training with ViT
# ...
# then, to finetune, just pass the ViT into the Adapter class
# you can do this for multiple Adapters, as shown below
adapter1 = Adapter(
vit = v,
num_classes = 2, # number of output classes for this specific task
num_memories_per_layer = 5 # number of learnable memories per layer, 10 was sufficient in paper
)
logits1 = adapter1(img) # (4, 2) - predict 2 classes off frozen ViT backbone with learnable memories and task specific head
# yet another task to finetune on, this time with 4 classes
adapter2 = Adapter(
vit = v,
num_classes = 4,
num_memories_per_layer = 10
)
logits2 = adapter2(img) # (4, 4) - predict 4 classes off frozen ViT backbone with learnable memories and task specific head
```
## Dino
<img src="./images/dino.png" width="350px"></img>
@@ -1294,6 +1498,52 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{renggli2022learning,
title = {Learning to Merge Tokens in Vision Transformers},
author = {Cedric Renggli and André Susano Pinto and Neil Houlsby and Basil Mustafa and Joan Puigcerver and Carlos Riquelme},
year = {2022},
eprint = {2202.12015},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{yang2022scalablevit,
title = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer},
author = {Rui Yang and Hailong Ma and Jie Wu and Yansong Tang and Xuefeng Xiao and Min Zheng and Xiu Li},
year = {2022},
eprint = {2203.10790},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@inproceedings{Touvron2022ThreeTE,
title = {Three things everyone should know about Vision Transformers},
author = {Hugo Touvron and Matthieu Cord and Alaaeldin El-Nouby and Jakob Verbeek and Herv'e J'egou},
year = {2022}
}
```
```bibtex
@inproceedings{Sandler2022FinetuningIT,
title = {Fine-tuning Image Transformers using Learnable Memory},
author = {Mark Sandler and Andrey Zhmoginov and Max Vladymyrov and Andrew Jackson},
year = {2022}
}
```
```bibtex
@inproceedings{Li2022SepViTSV,
title = {SepViT: Separable Vision Transformer},
author = {Wei Li and Xing Wang and Xin Xia and Jie Wu and Xuefeng Xiao and Minghang Zheng and Shiping Wen},
year = {2022}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

BIN
images/parallel-vit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

BIN
images/patch_merger.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

BIN
images/scalable-vit-1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 79 KiB

BIN
images/scalable-vit-2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 62 KiB

BIN
images/sep-vit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 142 KiB

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.26.4',
version = '0.32.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
@@ -15,7 +15,7 @@ setup(
'image recognition'
],
install_requires=[
'einops>=0.3',
'einops>=0.4.1',
'torch>=1.6',
'torchvision'
],

View File

@@ -139,6 +139,8 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.output_num_tokens = output_num_tokens
@@ -163,6 +165,7 @@ class Attention(nn.Module):
dots = dots.masked_fill(~dots_mask, mask_value)
attn = self.attend(dots)
attn = self.dropout(attn)
sampled_token_ids = None

View File

@@ -76,6 +76,7 @@ class Attention(nn.Module):
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))
@@ -96,7 +97,10 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax
attn = self.attend(dots)
attn = self.dropout(attn)
attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax
out = einsum('b h i j, b h j d -> b h i d', attn, v)

View File

@@ -2,7 +2,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
# Pre-defined CCT Models
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# CCT Models
__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']
@@ -55,8 +61,8 @@ def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
padding=padding,
*args, **kwargs)
# modules
# Modules
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
super().__init__()
@@ -308,6 +314,7 @@ class CCT(nn.Module):
pooling_padding=1,
*args, **kwargs):
super(CCT, self).__init__()
img_height, img_width = pair(img_size)
self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
n_output_channels=embedding_dim,
@@ -324,8 +331,8 @@ class CCT(nn.Module):
self.classifier = TransformerClassifier(
sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
height=img_size,
width=img_size),
height=img_height,
width=img_width),
embedding_dim=embedding_dim,
seq_pool=True,
dropout_rate=0.,
@@ -336,4 +343,3 @@ class CCT(nn.Module):
def forward(self, x):
x = self.tokenizer(x)
return self.classifier(x)

View File

@@ -48,6 +48,8 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -69,6 +71,7 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')

View File

@@ -62,9 +62,9 @@ class LayerNorm(nn.Module):
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
def FeedForward(dim, mult = 4, dropout = 0.):
return nn.Sequential(
@@ -95,6 +95,9 @@ class Attention(nn.Module):
self.window_size = window_size
self.norm = LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1)
@@ -151,6 +154,7 @@ class Attention(nn.Module):
# attend
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# merge heads

View File

@@ -30,9 +30,9 @@ class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
@@ -76,6 +76,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)
@@ -94,6 +95,7 @@ class Attention(nn.Module):
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)

View File

@@ -42,6 +42,8 @@ class Attention(nn.Module):
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.dropout = nn.Dropout(dropout)
self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
self.reattn_norm = nn.Sequential(
@@ -64,6 +66,7 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = dots.softmax(dim=-1)
attn = self.dropout(attn)
# re-attention

View File

@@ -3,12 +3,16 @@ from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
def pair(t):
return t if isinstance(t, tuple) else (t, t)
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
image_size_h, image_size_w = pair(image_size)
assert image_size_h % patch_size == 0 and image_size_w % patch_size == 0, 'image dimensions must be divisible by the patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_patches = (image_size // patch_size) ** 2
num_patches = (image_size_h // patch_size) * (image_size_w // patch_size)
patch_dim = channels * patch_size ** 2
self.to_patch_embedding = nn.Sequential(

View File

@@ -0,0 +1,216 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# controlling freezing of layers
def set_module_requires_grad_(module, requires_grad):
for param in module.parameters():
param.requires_grad = requires_grad
def freeze_all_layers_(module):
set_module_requires_grad_(module, False)
def unfreeze_all_layers_(module):
set_module_requires_grad_(module, True)
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, attn_mask = None, memories = None):
x = self.norm(x)
x_kv = x # input for key / values projection
if exists(memories):
# add memories to key / values if it is passed in
memories = repeat(memories, 'n d -> b n d', b = x.shape[0]) if memories.ndim == 2 else memories
x_kv = torch.cat((x_kv, memories), dim = 1)
qkv = (self.to_q(x), *self.to_kv(x_kv).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(attn_mask):
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, attn_mask = None, memories = None):
for ind, (attn, ff) in enumerate(self.layers):
layer_memories = memories[ind] if exists(memories) else None
x = attn(x, attn_mask = attn_mask, memories = layer_memories) + x
x = ff(x) + x
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def img_to_tokens(self, img):
x = self.to_patch_embedding(img)
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = x.shape[0])
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding
x = self.dropout(x)
return x
def forward(self, img):
x = self.img_to_tokens(img)
x = self.transformer(x)
cls_tokens = x[:, 0]
return self.mlp_head(cls_tokens)
# adapter with learnable memories per layer, memory CLS token, and learnable adapter head
class Adapter(nn.Module):
def __init__(
self,
*,
vit,
num_memories_per_layer = 10,
num_classes = 2,
):
super().__init__()
assert isinstance(vit, ViT)
# extract some model variables needed
dim = vit.cls_token.shape[-1]
layers = len(vit.transformer.layers)
num_patches = vit.pos_embedding.shape[-2]
self.vit = vit
# freeze ViT backbone - only memories will be finetuned
freeze_all_layers_(vit)
# learnable parameters
self.memory_cls_token = nn.Parameter(torch.randn(dim))
self.memories_per_layer = nn.Parameter(torch.randn(layers, num_memories_per_layer, dim))
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
# specialized attention mask to preserve the output of the original ViT
# it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa
attn_mask = torch.ones((num_patches, num_patches), dtype = torch.bool)
attn_mask = F.pad(attn_mask, (1, num_memories_per_layer), value = False) # main tokens cannot attend to learnable memories per layer
attn_mask = F.pad(attn_mask, (0, 0, 1, 0), value = True) # memory CLS token can attend to everything
self.register_buffer('attn_mask', attn_mask)
def forward(self, img):
b = img.shape[0]
tokens = self.vit.img_to_tokens(img)
# add task specific memory tokens
memory_cls_tokens = repeat(self.memory_cls_token, 'd -> b 1 d', b = b)
tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)
# pass memories along with image tokens through transformer for attending
out = self.vit.transformer(tokens, memories = self.memories_per_layer, attn_mask = self.attn_mask)
# extract memory CLS tokens
memory_cls_tokens = out[:, 0]
# pass through task specific adapter head
return self.mlp_head(memory_cls_tokens)

View File

@@ -52,6 +52,7 @@ class Attention(nn.Module):
self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
out_batch_norm = nn.BatchNorm2d(dim_out)
nn.init.zeros_(out_batch_norm.weight)
@@ -100,6 +101,7 @@ class Attention(nn.Module):
dots = self.apply_pos_bias(dots)
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)

View File

@@ -78,6 +78,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
@@ -93,6 +94,7 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')

View File

@@ -54,6 +54,8 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
@@ -67,7 +69,10 @@ class Attention(nn.Module):
t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b p h n d -> b p n (h d)')
return self.to_out(out)

View File

@@ -20,9 +20,9 @@ class LayerNorm(nn.Module):
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
@@ -55,6 +55,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Sequential(
@@ -71,6 +72,7 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)

140
vit_pytorch/parallel_vit.py Normal file
View File

@@ -0,0 +1,140 @@
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class Parallel(nn.Module):
def __init__(self, *fns):
super().__init__()
self.fns = nn.ModuleList(fns)
def forward(self, x):
return sum([fn(x) for fn in self.fns])
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_parallel_branches = 2, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
attn_block = lambda: PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
ff_block = lambda: PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
for _ in range(depth):
self.layers.append(nn.ModuleList([
Parallel(*[attn_block() for _ in range(num_parallel_branches)]),
Parallel(*[ff_block() for _ in range(num_parallel_branches)]),
]))
def forward(self, x):
for attns, ffs in self.layers:
x = attns(x) + x
x = ffs(x) + x
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', num_parallel_branches = 2, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_parallel_branches, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)

View File

@@ -48,6 +48,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
@@ -63,6 +64,7 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')

View File

@@ -55,5 +55,5 @@ class Recorder(nn.Module):
target_device = self.device if self.device is not None else img.device
recordings = tuple(map(lambda t: t.to(target_device), self.recordings))
attns = torch.stack(recordings, dim = 1)
attns = torch.stack(recordings, dim = 1) if len(recordings) > 0 else None
return pred, attns

View File

@@ -61,8 +61,13 @@ class Attention(nn.Module):
inner_dim = dim_head * heads
self.norm = nn.LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, rel_pos_bias = None):
h = self.heads
@@ -86,6 +91,7 @@ class Attention(nn.Module):
sim = sim + rel_pos_bias
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# merge heads

View File

@@ -104,6 +104,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.use_ds_conv = use_ds_conv
@@ -148,6 +149,7 @@ class Attention(nn.Module):
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)

306
vit_pytorch/scalable_vit.py Normal file
View File

@@ -0,0 +1,306 @@
from functools import partial
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
# helper classes
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanLayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
class Downsample(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)
def forward(self, x):
return self.conv(x)
class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
def forward(self, x):
return self.proj(x) + x
# feedforward
class FeedForward(nn.Module):
def __init__(self, dim, expansion_factor = 4, dropout = 0.):
super().__init__()
inner_dim = dim * expansion_factor
self.net = nn.Sequential(
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# attention
class ScalableSelfAttention(nn.Module):
def __init__(
self,
dim,
heads = 8,
dim_key = 32,
dim_value = 32,
dropout = 0.,
reduction_factor = 1
):
super().__init__()
self.heads = heads
self.scale = dim_key ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(dim_value * heads, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
height, width, heads = *x.shape[-2:], self.heads
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
# split out heads
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
# similarity
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# attention
attn = self.attend(dots)
attn = self.dropout(attn)
# aggregate values
out = torch.matmul(attn, v)
# merge back heads
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = height, y = width)
return self.to_out(out)
class InteractiveWindowedSelfAttention(nn.Module):
def __init__(
self,
dim,
window_size,
heads = 8,
dim_key = 32,
dim_value = 32,
dropout = 0.
):
super().__init__()
self.heads = heads
self.scale = dim_key ** -0.5
self.window_size = window_size
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_v = nn.Conv2d(dim, dim_value * heads, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(dim_value * heads, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size
wsz_h, wsz_w = default(wsz, height), default(wsz, width)
assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
# get output of LIM
local_out = self.local_interactive_module(v)
# divide into window (and split out heads) for efficient self attention
q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz_h, w2 = wsz_w), (q, k, v))
# similarity
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# attention
attn = self.attend(dots)
attn = self.dropout(attn)
# aggregate values
out = torch.matmul(attn, v)
# reshape the windows back to full feature map (and merge heads)
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w)
# add LIM output
out = out + local_out
return self.to_out(out)
class Transformer(nn.Module):
def __init__(
self,
dim,
depth,
heads = 8,
ff_expansion_factor = 4,
dropout = 0.,
ssa_dim_key = 32,
ssa_dim_value = 32,
ssa_reduction_factor = 1,
iwsa_dim_key = 32,
iwsa_dim_value = 32,
iwsa_window_size = None,
norm_output = True
):
super().__init__()
self.layers = nn.ModuleList([])
for ind in range(depth):
is_first = ind == 0
self.layers.append(nn.ModuleList([
PreNorm(dim, ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout)),
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
PEG(dim) if is_first else None,
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
PreNorm(dim, InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout))
]))
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for ssa, ff1, peg, iwsa, ff2 in self.layers:
x = ssa(x) + x
x = ff1(x) + x
if exists(peg):
x = peg(x)
x = iwsa(x) + x
x = ff2(x) + x
return self.norm(x)
class ScalableViT(nn.Module):
def __init__(
self,
*,
num_classes,
dim,
depth,
heads,
reduction_factor,
window_size = None,
iwsa_dim_key = 32,
iwsa_dim_value = 32,
ssa_dim_key = 32,
ssa_dim_value = 32,
ff_expansion_factor = 4,
channels = 3,
dropout = 0.
):
super().__init__()
self.to_patches = nn.Conv2d(channels, dim, 7, stride = 4, padding = 3)
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
num_stages = len(depth)
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
hyperparams_per_stage = [
heads,
ssa_dim_key,
ssa_dim_value,
reduction_factor,
iwsa_dim_key,
iwsa_dim_value,
window_size,
]
hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))
self.layers = nn.ModuleList([])
for ind, (layer_dim, layer_depth, layer_heads, layer_ssa_dim_key, layer_ssa_dim_value, layer_ssa_reduction_factor, layer_iwsa_dim_key, layer_iwsa_dim_value, layer_window_size) in enumerate(zip(dims, depth, *hyperparams_per_stage)):
is_last = ind == (num_stages - 1)
self.layers.append(nn.ModuleList([
Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_expansion_factor = ff_expansion_factor, dropout = dropout, ssa_dim_key = layer_ssa_dim_key, ssa_dim_value = layer_ssa_dim_value, ssa_reduction_factor = layer_ssa_reduction_factor, iwsa_dim_key = layer_iwsa_dim_key, iwsa_dim_value = layer_iwsa_dim_value, iwsa_window_size = layer_window_size, norm_output = not is_last),
Downsample(layer_dim, layer_dim * 2) if not is_last else None
]))
self.mlp_head = nn.Sequential(
Reduce('b d h w -> b d', 'mean'),
nn.LayerNorm(dims[-1]),
nn.Linear(dims[-1], num_classes)
)
def forward(self, img):
x = self.to_patches(img)
for transformer, downsample in self.layers:
x = transformer(x)
if exists(downsample):
x = downsample(x)
return self.mlp_head(x)

294
vit_pytorch/sep_vit.py Normal file
View File

@@ -0,0 +1,294 @@
from functools import partial
import torch
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# helpers
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
# helper classes
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanLayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
class OverlappingPatchEmbed(nn.Module):
def __init__(self, dim_in, dim_out, stride = 2):
super().__init__()
kernel_size = stride * 2 - 1
padding = kernel_size // 2
self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)
def forward(self, x):
return self.conv(x)
class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
def forward(self, x):
return self.proj(x) + x
# feedforward
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# attention
class DSSA(nn.Module):
def __init__(
self,
dim,
heads = 8,
dim_head = 32,
dropout = 0.,
window_size = 7
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
self.window_size = window_size
inner_dim = dim_head * heads
self.attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
)
self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
# window tokens
self.window_tokens = nn.Parameter(torch.randn(dim))
# prenorm and non-linearity for window tokens
# then projection to queries and keys for window tokens
self.window_tokens_to_qk = nn.Sequential(
nn.LayerNorm(dim_head),
nn.GELU(),
Rearrange('b h n c -> b (h c) n'),
nn.Conv1d(inner_dim, inner_dim * 2, 1),
Rearrange('b (h c) n -> b h n c', h = heads),
)
# window attention
self.window_attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
"""
einstein notation
b - batch
c - channels
w1 - window size (height)
w2 - also window size (width)
i - sequence dimension (source)
j - sequence dimension (target dimension to be reduced)
h - heads
x - height of feature map divided by window size
y - width of feature map divided by window size
"""
batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
num_windows = (height // wsz) * (width // wsz)
# fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
# add windowing tokens
w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0])
x = torch.cat((w, x), dim = -1)
# project for queries, keys, value
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
# split out heads
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
# scale
q = q * self.scale
# similarity
dots = einsum('b h i d, b h j d -> b h i j', q, k)
# attention
attn = self.attend(dots)
# aggregate values
out = torch.matmul(attn, v)
# split out windowed tokens
window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]
# early return if there is only 1 window
if num_windows == 1:
fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
return self.to_out(fmap)
# carry out the pointwise attention, the main novelty in the paper
window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz)
windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz)
# windowed queries and keys (preceded by prenorm activation)
w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)
# scale
w_q = w_q * self.scale
# similarities
w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)
w_attn = self.window_attend(w_dots)
# aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one i haven't seen before)
aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)
# fold back the windows and then combine heads for aggregation
fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
return self.to_out(fmap)
class Transformer(nn.Module):
def __init__(
self,
dim,
depth,
dim_head = 32,
heads = 8,
ff_mult = 4,
dropout = 0.,
norm_output = True
):
super().__init__()
self.layers = nn.ModuleList([])
for ind in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = dropout)),
]))
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class SepViT(nn.Module):
def __init__(
self,
*,
num_classes,
dim,
depth,
heads,
window_size = 7,
dim_head = 32,
ff_mult = 4,
channels = 3,
dropout = 0.
):
super().__init__()
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
num_stages = len(depth)
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
dims = (channels, *dims)
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
strides = (4, *((2,) * (num_stages - 1)))
hyperparams_per_stage = [heads, window_size]
hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))
self.layers = nn.ModuleList([])
for ind, ((layer_dim_in, layer_dim), layer_depth, layer_stride, layer_heads, layer_window_size) in enumerate(zip(dim_pairs, depth, strides, *hyperparams_per_stage)):
is_last = ind == (num_stages - 1)
self.layers.append(nn.ModuleList([
OverlappingPatchEmbed(layer_dim_in, layer_dim, stride = layer_stride),
PEG(layer_dim),
Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_mult = ff_mult, dropout = dropout, norm_output = not is_last),
]))
self.mlp_head = nn.Sequential(
Reduce('b d h w -> b d', 'mean'),
nn.LayerNorm(dims[-1]),
nn.Linear(dims[-1], num_classes)
)
def forward(self, x):
for ope, peg, transformer in self.layers:
x = ope(x)
x = peg(x)
x = transformer(x)
return self.mlp_head(x)

View File

@@ -38,9 +38,9 @@ class LayerNorm(nn.Module):
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
@@ -130,6 +130,8 @@ class GlobalAttention(nn.Module):
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)
self.dropout = nn.Dropout(dropout)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
@@ -145,6 +147,7 @@ class GlobalAttention(nn.Module):
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = dots.softmax(dim = -1)
attn = self.dropout(attn)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)

View File

@@ -42,6 +42,8 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
@@ -56,6 +58,7 @@ class Attention(nn.Module):
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
@@ -111,7 +114,7 @@ class ViT(nn.Module):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)

View File

@@ -42,6 +42,8 @@ class LSA(nn.Module):
self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
@@ -60,6 +62,7 @@ class LSA(nn.Module):
dots = dots.masked_fill(mask, mask_value)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')

View File

@@ -0,0 +1,147 @@
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# helpers
def exists(val):
return val is not None
def default(val ,d):
return val if exists(val) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# patch merger class
class PatchMerger(nn.Module):
def __init__(self, dim, num_tokens_out):
super().__init__()
self.scale = dim ** -0.5
self.norm = nn.LayerNorm(dim)
self.queries = nn.Parameter(torch.randn(num_tokens_out, dim))
def forward(self, x):
x = self.norm(x)
sim = torch.matmul(self.queries, x.transpose(-1, -2)) * self.scale
attn = sim.softmax(dim = -1)
return torch.matmul(attn, x)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8):
super().__init__()
self.layers = nn.ModuleList([])
self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper
self.patch_merger = PatchMerger(dim = dim, num_tokens_out = patch_merge_num_tokens)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for index, (attn, ff) in enumerate(self.layers):
x = attn(x) + x
x = ff(x) + x
if index == self.patch_merge_layer_index:
x = self.patch_merger(x)
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, patch_merge_layer, patch_merge_num_tokens)
self.mlp_head = nn.Sequential(
Reduce('b n d -> b d', 'mean'),
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
x += self.pos_embedding[:, :n]
x = self.dropout(x)
x = self.transformer(x)
return self.mlp_head(x)