Compare commits

...

32 Commits

Author SHA1 Message Date
Phil Wang
d27721a85a add scalable vit, from bytedance AI 2022-03-22 17:02:47 -07:00
Phil Wang
cb22cbbd19 update to einops 0.4, which is torchscript jit friendly 2022-03-22 13:58:00 -07:00
Phil Wang
6db20debb4 add patch merger 2022-03-01 16:50:17 -08:00
Phil Wang
1bae5d3cc5 allow for rectangular images for efficient adapter 2022-01-31 08:55:31 -08:00
Phil Wang
25b384297d return None from extractor if no attention layers 2022-01-28 17:49:58 -08:00
Phil Wang
64a07f50e6 epsilon should be inside square root 2022-01-24 17:24:41 -08:00
Phil Wang
126d204ff2 fix block repeats in readme example for Nest 2022-01-22 21:32:53 -08:00
Phil Wang
c1528acd46 fix feature maps in Nest, thanks to @MarkYangjiayi 2022-01-22 13:17:30 -08:00
Phil Wang
1cc0f182a6 decoder positional embedding needs to be reapplied https://twitter.com/giffmana/status/1479195631587631104 2022-01-06 13:14:41 -08:00
Phil Wang
28eaba6115 0.26.2 2022-01-03 12:56:34 -08:00
Phil Wang
0082301f9e build @jrounds suggestion 2022-01-03 12:56:25 -08:00
Phil Wang
91ed738731 0.26.1 2021-12-30 19:31:26 -08:00
Phil Wang
1b58daa20a Merge pull request #186 from chinhsuanwu/mobilevit
Update MobileViT
2021-12-30 19:31:01 -08:00
chinhsuanwu
f2414b2c1b Update MobileViT 2021-12-30 05:52:23 +08:00
Phil Wang
891b92eb74 readme 2021-12-28 16:00:00 -08:00
Phil Wang
70ba532599 add ViT for small datasets https://arxiv.org/abs/2112.13492 2021-12-28 10:58:21 -08:00
Phil Wang
e52ac41955 allow extractor to only return embeddings, to ready for vision transformers to be used in x-clip 2021-12-25 12:31:21 -08:00
Phil Wang
0891885485 include tests in package for conda 2021-12-22 12:44:29 -08:00
Phil Wang
976f489230 add some tests 2021-12-22 09:13:31 -08:00
Phil Wang
2c368d1d4e add extractor wrapper 2021-12-21 11:11:39 -08:00
Phil Wang
b983bbee39 release MobileViT, from @murufeng 2021-12-21 10:22:59 -08:00
Phil Wang
86a7302ba6 Merge pull request #181 from murufeng/main
Add MobileViT
2021-12-21 09:51:56 -08:00
murufeng
89d3a04b3f Add files via upload 2021-12-21 20:48:34 +08:00
murufeng
e7075c64aa Update README.md 2021-12-21 20:44:30 +08:00
murufeng
5ea1559e4c Add files via upload 2021-12-21 20:41:01 +08:00
Phil Wang
f4b0b14094 add ATS to table of contents 2021-12-03 20:07:18 -08:00
Phil Wang
365b4d931e add adaptive token sampling paper 2021-12-03 19:52:40 -08:00
Phil Wang
79c864d796 link to community youtuber 2021-11-24 08:13:52 -08:00
Phil Wang
b45c1356a1 cleanup 2021-11-22 22:53:02 -08:00
Phil Wang
ff44d97cb0 make initial channels customizable for PiT 2021-11-22 18:08:49 -08:00
Phil Wang
d35345df6a remove wip 2021-11-22 17:43:04 -08:00
Phil Wang
b69b5af34f dynamic positional bias for crossformer the more efficient way as described in appendix of paper 2021-11-22 17:39:36 -08:00
25 changed files with 1558 additions and 37 deletions

33
.github/workflows/python-test.yml vendored Normal file

@@ -0,0 +1,33 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Test
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
python setup.py test

1
MANIFEST.in Normal file

@@ -0,0 +1 @@
recursive-include tests *

284
README.md

@@ -16,11 +16,17 @@
- [LeViT](#levit)
- [CvT](#cvt)
- [Twins SVT](#twins-svt)
- [CrossFormer](#crossformer)
- [RegionViT](#regionvit)
- [ScalableViT](#scalablevit)
- [NesT](#nest)
- [MobileViT](#mobilevit)
- [Masked Autoencoder](#masked-autoencoder)
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
- [Adaptive Token Sampling](#adaptive-token-sampling)
- [Patch Merger](#patch-merger)
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
- [Dino](#dino)
- [Accessing Attention](#accessing-attention)
- [Research Ideas](#research-ideas)
@@ -493,7 +499,7 @@ img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```
## CrossFormer (wip)
## CrossFormer
<img src="./images/crossformer.png" width="400px"></img>
@@ -520,6 +526,38 @@ img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```
## ScalableViT
<img src="./images/scalable-vit-1.png" width="400px"></img>
<img src="./images/scalable-vit-2.png" width="400px"></img>
This Bytedance AI <a href="https://arxiv.org/abs/2203.10790">paper</a> proposes the Scalable Self Attention (SSA) and the Interactive Windowed Self Attention (IWSA) modules. The SSA alleviates the computation needed at earlier stages by reducing the key / value feature map by some factor (`reduction_factor`), while modulating the dimension of the queries and keys (`ssa_dim_key`). The IWSA performs self attention within local windows, similar to other vision transformer papers. However, they add a residual of the values, passed through a convolution of kernel size 3, which they named Local Interactive Module (LIM).
They make the claim in this paper that this scheme outperforms Swin Transformer, and also demonstrate competitive performance against Crossformer.
You can use it as follows (ex. ScalableViT-S)
```python
import torch
from vit_pytorch.scalable_vit import ScalableViT
model = ScalableViT(
num_classes = 1000,
dim = 64, # starting model dimension. at every stage, dimension is doubled
heads = (2, 4, 8, 16), # number of attention heads at each stage
depth = (2, 2, 20, 2), # number of transformer blocks at each stage
ssa_dim_key = (40, 40, 40, 32), # the dimension of the attention keys (and queries) for SSA. in the paper, they represented this as a scale factor on the base dimension per key (ssa_dim_key / dim_key)
reduction_factor = (8, 4, 2, 1), # downsampling of the key / values in SSA. in the paper, this was represented as (reduction_factor ** -2)
window_size = (64, 32, None, None), # window size of the IWSA at each stage. None means no windowing needed
dropout = 0.1, # attention and feedforward dropout
).cuda()
img = torch.randn(1, 3, 256, 256).cuda()
preds = model(img) # (1, 1000)
```
## NesT
<img src="./images/nest.png" width="400px"></img>
@@ -538,7 +576,7 @@ nest = NesT(
dim = 96,
heads = 3,
num_hierarchies = 3, # number of hierarchies
block_repeats = (8, 4, 1), # the number of transformer blocks at each hierarchy, starting from the bottom
block_repeats = (2, 2, 8), # the number of transformer blocks at each hierarchy, starting from the bottom
num_classes = 1000
)
@@ -547,6 +585,31 @@ img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```
## MobileViT
<img src="./images/mbvit.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2110.02178">paper</a> introduces MobileViT, a light-weight and general-purpose vision transformer for mobile devices. MobileViT presents a different perspective on the global processing of information with transformers.
You can use it with the following code (ex. mobilevit_xs)
```python
import torch
from vit_pytorch.mobile_vit import MobileViT
mbvit_xs = MobileViT(
image_size = (256, 256),
dims = [96, 120, 144],
channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
num_classes = 1000
)
img = torch.randn(1, 3, 256, 256)
pred = mbvit_xs(img) # (1, 1000)
```
## Simple Masked Image Modeling
<img src="./images/simmim.png" width="400px"/>
@@ -595,6 +658,8 @@ A new <a href="https://arxiv.org/abs/2111.06377">Kaiming He paper</a> proposes a
<a href="https://www.youtube.com/watch?v=LKixq2S2Pz8">DeepReader quick paper review</a>
<a href="https://www.youtube.com/watch?v=Dp6iICL2dVI">AI Coffeebreak with Letitia</a>
You can use it with the following code
```python
@@ -676,6 +741,131 @@ for _ in range(100):
torch.save(model.state_dict(), './pretrained-net.pt')
```
## Adaptive Token Sampling
<img src="./images/ats.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2111.15667">paper</a> proposes to use the CLS attention scores, re-weighted by the norms of the value heads, as a means to discard unimportant tokens at different layers.
```python
import torch
from vit_pytorch.ats_vit import ViT
v = ViT(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 1024,
depth = 6,
max_tokens_per_depth = (256, 128, 64, 32, 16, 8), # a tuple that denotes the maximum number of tokens that any given layer should have. if the layer has greater than this amount, it will undergo adaptive token sampling
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(4, 3, 256, 256)
preds = v(img) # (4, 1000)
# you can also get a list of the final sampled patch ids
# a value of -1 denotes padding
preds, token_ids = v(img, return_sampled_token_ids = True) # (4, 1000), (4, <=8)
```
## Patch Merger
<img src="./images/patch_merger.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2202.12015">paper</a> proposes a simple module (Patch Merger) for reducing the number of tokens at any layer of a vision transformer without sacrificing performance.
```python
import torch
from vit_pytorch.vit_with_patch_merger import ViT
v = ViT(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 1024,
depth = 12,
heads = 8,
patch_merge_layer = 6, # at which transformer layer to do patch merging
patch_merge_num_tokens = 8, # the output number of tokens from the patch merge
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(4, 3, 256, 256)
preds = v(img) # (4, 1000)
```
One can also use the `PatchMerger` module by itself
```python
import torch
from vit_pytorch.vit_with_patch_merger import PatchMerger
merger = PatchMerger(
dim = 1024,
num_tokens_out = 8 # output number of tokens
)
features = torch.randn(4, 256, 1024) # (batch, num tokens, dimension)
out = merger(features) # (4, 8, 1024)
```
## Vision Transformer for Small Datasets
<img src="./images/vit_for_small_datasets.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2112.13492">paper</a> proposes a new image to patch function that incorporates shifts of the image, before normalizing and dividing the image into patches. I have found shifting to be extremely helpful in some other transformers work, so decided to include this for further explorations. It also includes the `LSA` with the learned temperature and masking out of a token's attention to itself.
You can use it as follows:
```python
import torch
from vit_pytorch.vit_for_small_dataset import ViT
v = ViT(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(4, 3, 256, 256)
preds = v(img) # (4, 1000)
```
You can also use the `SPT` from this paper as a standalone module
```python
import torch
from vit_pytorch.vit_for_small_dataset import SPT
spt = SPT(
dim = 1024,
patch_size = 16,
channels = 3
)
img = torch.randn(4, 3, 256, 256)
tokens = spt(img) # (4, 256, 1024)
```
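
The `LSA` attention from this paper (learned temperature, with each token's attention to itself masked out) can also be tried on its own; a minimal sketch, assuming it is importable from the same module as `SPT`

```python
import torch
from vit_pytorch.vit_for_small_dataset import LSA

lsa = LSA(
    dim = 1024,
    heads = 16,
    dim_head = 64
)

tokens = torch.randn(4, 256, 1024) # (batch, num tokens, dimension)
out = lsa(tokens) # (4, 256, 1024)
```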
## Dino
<img src="./images/dino.png" width="350px"></img>
@@ -771,6 +961,41 @@ to cleanup the class and the hooks once you have collected enough data
v = v.eject() # wrapper is discarded and original ViT instance is returned
```
## Accessing Embeddings
You can similarly access the embeddings with the `Extractor` wrapper
```python
import torch
from vit_pytorch.vit import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
# import Extractor and wrap the ViT
from vit_pytorch.extractor import Extractor
v = Extractor(v)
# forward pass now returns predictions and the embeddings
img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)
# there is one extra token due to the CLS token
embeddings # (1, 65, 1024) - (batch x patches x model dim)
```
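
The `Extractor` also accepts a `return_embeddings_only` flag, for when only the embeddings are needed (e.g. plugging a vision transformer into x-clip); a minimal sketch, with shapes assumed to match the example above

```python
import torch
from vit_pytorch.vit import ViT
from vit_pytorch.extractor import Extractor

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# wrap the ViT, asking only for the embeddings on forward
v = Extractor(v, return_embeddings_only = True)

img = torch.randn(1, 3, 256, 256)
embeddings = v(img) # (1, 65, 1024) - embeddings only, no logits
```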
## Research Ideas
### Efficient Attention
@@ -1116,6 +1341,61 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{fayyaz2021ats,
title = {ATS: Adaptive Token Sampling For Efficient Vision Transformers},
author = {Mohsen Fayyaz and Soroush Abbasi Kouhpayegani and Farnoush Rezaei Jafari and Eric Sommerlade and Hamid Reza Vaezi Joze and Hamed Pirsiavash and Juergen Gall},
year = {2021},
eprint = {2111.15667},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{mehta2021mobilevit,
title = {MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
author = {Sachin Mehta and Mohammad Rastegari},
year = {2021},
eprint = {2110.02178},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{lee2021vision,
title = {Vision Transformer for Small-Size Datasets},
author = {Seung Hoon Lee and Seunghyun Lee and Byung Cheol Song},
year = {2021},
eprint = {2112.13492},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{renggli2022learning,
title = {Learning to Merge Tokens in Vision Transformers},
author = {Cedric Renggli and André Susano Pinto and Neil Houlsby and Basil Mustafa and Joan Puigcerver and Carlos Riquelme},
year = {2022},
eprint = {2202.12015},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{yang2022scalablevit,
title = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer},
author = {Rui Yang and Hailong Ma and Jie Wu and Yansong Tang and Xuefeng Xiao and Min Zheng and Xiu Li},
year = {2022},
eprint = {2203.10790},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},

BIN
images/ats.png Normal file (new binary, 198 KiB)

BIN
images/mbvit.png Normal file (new binary, 206 KiB)

BIN
images/patch_merger.png Normal file (new binary, 54 KiB)

BIN
images/scalable-vit-1.png Normal file (new binary, 79 KiB)

BIN
images/scalable-vit-2.png Normal file (new binary, 62 KiB)

BIN
images/vit_for_small_datasets.png Normal file (new binary, 80 KiB)

setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.24.0',
version = '0.28.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
@@ -15,10 +15,16 @@ setup(
'image recognition'
],
install_requires=[
'einops>=0.3',
'einops>=0.4.1',
'torch>=1.6',
'torchvision'
],
setup_requires=[
'pytest-runner',
],
tests_require=[
'pytest'
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',

20
tests/test.py Normal file

@@ -0,0 +1,20 @@
import torch
from vit_pytorch import ViT
def test():
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img)
assert preds.shape == (1, 1000), 'correct logits outputted'

262
vit_pytorch/ats_vit.py Normal file

@@ -0,0 +1,262 @@
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# adaptive token sampling functions and classes
def log(t, eps = 1e-6):
return torch.log(t + eps)
def sample_gumbel(shape, device, dtype, eps = 1e-6):
u = torch.empty(shape, device = device, dtype = dtype).uniform_(0, 1)
return -log(-log(u, eps), eps)
def batched_index_select(values, indices, dim = 1):
value_dims = values.shape[(dim + 1):]
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
indices = indices[(..., *((None,) * len(value_dims)))]
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
value_expand_len = len(indices_shape) - (dim + 1)
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
value_expand_shape = [-1] * len(values.shape)
expand_slice = slice(dim, (dim + value_expand_len))
value_expand_shape[expand_slice] = indices.shape[expand_slice]
values = values.expand(*value_expand_shape)
dim += value_expand_len
return values.gather(dim, indices)
class AdaptiveTokenSampling(nn.Module):
def __init__(self, output_num_tokens, eps = 1e-6):
super().__init__()
self.eps = eps
self.output_num_tokens = output_num_tokens
def forward(self, attn, value, mask):
heads, output_num_tokens, eps, device, dtype = attn.shape[1], self.output_num_tokens, self.eps, attn.device, attn.dtype
# first get the attention values for CLS token to all other tokens
cls_attn = attn[..., 0, 1:]
# calculate the norms of the values, for weighting the scores, as described in the paper
value_norms = value[..., 1:, :].norm(dim = -1)
# weigh the attention scores by the norm of the values, sum across all heads
cls_attn = einsum('b h n, b h n -> b n', cls_attn, value_norms)
# normalize to 1
normed_cls_attn = cls_attn / (cls_attn.sum(dim = -1, keepdim = True) + eps)
# instead of using inverse transform sampling, going to invert the softmax and use gumbel-max sampling instead
pseudo_logits = log(normed_cls_attn)
# mask out pseudo logits for gumbel-max sampling
mask_without_cls = mask[:, 1:]
mask_value = -torch.finfo(attn.dtype).max / 2
pseudo_logits = pseudo_logits.masked_fill(~mask_without_cls, mask_value)
# expand k times, k being the adaptive sampling number
pseudo_logits = repeat(pseudo_logits, 'b n -> b k n', k = output_num_tokens)
pseudo_logits = pseudo_logits + sample_gumbel(pseudo_logits.shape, device = device, dtype = dtype)
# gumbel-max and add one to reserve 0 for padding / mask
sampled_token_ids = pseudo_logits.argmax(dim = -1) + 1
# calculate unique using torch.unique and then pad the sequence from the right
unique_sampled_token_ids_list = [torch.unique(t, sorted = True) for t in torch.unbind(sampled_token_ids)]
unique_sampled_token_ids = pad_sequence(unique_sampled_token_ids_list, batch_first = True)
# calculate the new mask, based on the padding
new_mask = unique_sampled_token_ids != 0
# CLS token never gets masked out (gets a value of True)
new_mask = F.pad(new_mask, (1, 0), value = True)
# prepend a 0 token id to keep the CLS attention scores
unique_sampled_token_ids = F.pad(unique_sampled_token_ids, (1, 0), value = 0)
expanded_unique_sampled_token_ids = repeat(unique_sampled_token_ids, 'b n -> b h n', h = heads)
# gather the new attention scores
new_attn = batched_index_select(attn, expanded_unique_sampled_token_ids, dim = 2)
# return the sampled attention scores, new mask (denoting padding), as well as the sampled token indices (for the residual)
return new_attn, new_mask, unique_sampled_token_ids
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_tokens = None):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.output_num_tokens = output_num_tokens
self.ats = AdaptiveTokenSampling(output_num_tokens) if exists(output_num_tokens) else None
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, *, mask):
num_tokens = x.shape[1]
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(mask):
dots_mask = rearrange(mask, 'b i -> b 1 i 1') * rearrange(mask, 'b j -> b 1 1 j')
mask_value = -torch.finfo(dots.dtype).max
dots = dots.masked_fill(~dots_mask, mask_value)
attn = self.attend(dots)
sampled_token_ids = None
# if adaptive token sampling is enabled
# and number of tokens is greater than the number of output tokens
if exists(self.output_num_tokens) and (num_tokens - 1) > self.output_num_tokens:
attn, mask, sampled_token_ids = self.ats(attn, v, mask = mask)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out), mask, sampled_token_ids
class Transformer(nn.Module):
def __init__(self, dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
assert len(max_tokens_per_depth) == depth, 'max_tokens_per_depth must be a tuple of length that is equal to the depth of the transformer'
assert sorted(max_tokens_per_depth, reverse = True) == list(max_tokens_per_depth), 'max_tokens_per_depth must be in decreasing order'
assert min(max_tokens_per_depth) > 0, 'max_tokens_per_depth must have at least 1 token at any layer'
self.layers = nn.ModuleList([])
for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
b, n, device = *x.shape[:2], x.device
# use mask to keep track of the paddings when sampling tokens
# as the duplicates (when sampling) are just removed, as mentioned in the paper
mask = torch.ones((b, n), device = device, dtype = torch.bool)
token_ids = torch.arange(n, device = device)
token_ids = repeat(token_ids, 'n -> b n', b = b)
for attn, ff in self.layers:
attn_out, mask, sampled_token_ids = attn(x, mask = mask)
# when token sampling, one needs to then gather the residual tokens with the sampled token ids
if exists(sampled_token_ids):
x = batched_index_select(x, sampled_token_ids, dim = 1)
token_ids = batched_index_select(token_ids, sampled_token_ids, dim = 1)
x = x + attn_out
x = ff(x) + x
return x, token_ids
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, max_tokens_per_depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img, return_sampled_token_ids = False):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x, token_ids = self.transformer(x)
logits = self.mlp_head(x[:, 0])
if return_sampled_token_ids:
# remove CLS token and decrement by 1 to make -1 the padding
token_ids = token_ids[:, 1:] - 1
return logits, token_ids
return logits
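
A minimal sketch of what the `batched_index_select` helper above does (illustrative only; assumes it is importable from `vit_pytorch.ats_vit`) - it gathers per-sample rows, which is how the residual tokens are kept in sync with the sampled token ids

```python
import torch
from vit_pytorch.ats_vit import batched_index_select

values = torch.randn(2, 5, 4)            # (batch, num tokens, dim)
indices = torch.tensor([[0, 3], [1, 4]]) # per-sample token ids to keep
selected = batched_index_select(values, indices, dim = 1) # (2, 2, 4)
```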

vit_pytorch/crossformer.py

@@ -6,18 +6,9 @@ import torch.nn.functional as F
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def divisible_by(val, d):
return (val % d) == 0
# cross embed layer
class CrossEmbedLayer(nn.Module):
@@ -71,9 +62,9 @@ class LayerNorm(nn.Module):
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
def FeedForward(dim, mult = 4, dropout = 0.):
return nn.Sequential(
@@ -103,12 +94,25 @@ class Attention(nn.Module):
self.attn_type = attn_type
self.window_size = window_size
self.dpb = DynamicPositionBias(dim // 4)
self.norm = LayerNorm(dim)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1)
# positions
self.dpb = DynamicPositionBias(dim // 4)
# calculate and store indices for retrieving bias
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = grid[:, None] - grid[None, :]
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
def forward(self, x):
*_, height, width, heads, wsz, device = *x.shape, self.heads, self.window_size, x.device
@@ -136,12 +140,11 @@ class Attention(nn.Module):
# add dynamic positional bias
i_pos = torch.arange(wsz, device = device)
j_pos = torch.arange(wsz, device = device)
grid = torch.stack(torch.meshgrid(i_pos, j_pos))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_ij = grid[:, None] - grid[None, :]
rel_pos_bias = self.dpb(rel_ij.float())
pos = torch.arange(-wsz, wsz + 1, device = device)
rel_pos = torch.stack(torch.meshgrid(pos, pos))
rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
biases = self.dpb(rel_pos.float())
rel_pos_bias = biases[self.rel_pos_indices]
sim = sim + rel_pos_bias


@@ -30,9 +30,9 @@ class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):

vit_pytorch/efficient.py

@@ -3,12 +3,16 @@ from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
def pair(t):
return t if isinstance(t, tuple) else (t, t)
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
image_size_h, image_size_w = pair(image_size)
assert image_size_h % patch_size == 0 and image_size_w % patch_size == 0, 'image dimensions must be divisible by the patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_patches = (image_size // patch_size) ** 2
num_patches = (image_size_h // patch_size) * (image_size_w // patch_size)
patch_dim = channels * patch_size ** 2
self.to_patch_embedding = nn.Sequential(

70
vit_pytorch/extractor.py Normal file

@@ -0,0 +1,70 @@
import torch
from torch import nn
def exists(val):
return val is not None
class Extractor(nn.Module):
def __init__(
self,
vit,
device = None,
layer_name = 'transformer',
layer_save_input = False,
return_embeddings_only = False
):
super().__init__()
self.vit = vit
self.data = None
self.latents = None
self.hooks = []
self.hook_registered = False
self.ejected = False
self.device = device
self.layer_name = layer_name
self.layer_save_input = layer_save_input # whether to save input or output of layer
self.return_embeddings_only = return_embeddings_only
def _hook(self, _, inputs, output):
tensor_to_save = inputs if self.layer_save_input else output
self.latents = tensor_to_save.clone().detach()
def _register_hook(self):
assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
layer = getattr(self.vit, self.layer_name)
handle = layer.register_forward_hook(self._hook)
self.hooks.append(handle)
self.hook_registered = True
def eject(self):
self.ejected = True
for hook in self.hooks:
hook.remove()
self.hooks.clear()
return self.vit
def clear(self):
del self.latents
self.latents = None
def forward(
self,
img,
return_embeddings_only = False
):
assert not self.ejected, 'extractor has been ejected, cannot be used anymore'
self.clear()
if not self.hook_registered:
self._register_hook()
pred = self.vit(img)
target_device = self.device if exists(self.device) else img.device
latents = self.latents.to(target_device)
if return_embeddings_only or self.return_embeddings_only:
return latents
return pred, latents

vit_pytorch/mae.py

@@ -71,6 +71,10 @@ class MAE(nn.Module):
decoder_tokens = self.enc_to_dec(encoded_tokens)
# reapply decoder position embedding to unmasked tokens
decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)

247
vit_pytorch/mobile_vit.py Normal file

@@ -0,0 +1,247 @@
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Reduce
# helpers
def conv_1x1_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.SiLU()
)
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
return nn.Sequential(
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.SiLU()
)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim=-1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(
t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b p h n d -> b p n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
"""Transformer block described in ViT.
Paper: https://arxiv.org/abs/2010.11929
Based on: https://github.com/lucidrains/vit-pytorch
"""
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class MV2Block(nn.Module):
"""MV2 block described in MobileNetV2.
Paper: https://arxiv.org/pdf/1801.04381
Based on: https://github.com/tonylins/pytorch-mobilenet-v2
"""
def __init__(self, inp, oup, stride=1, expansion=4):
super().__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(inp * expansion)
self.use_res_connect = self.stride == 1 and inp == oup
if expansion == 1:
self.conv = nn.Sequential(
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
else:
self.conv = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
def forward(self, x):
out = self.conv(x)
if self.use_res_connect:
out = out + x
return out
class MobileViTBlock(nn.Module):
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
super().__init__()
self.ph, self.pw = patch_size
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
self.conv2 = conv_1x1_bn(channel, dim)
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
self.conv3 = conv_1x1_bn(dim, channel)
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
def forward(self, x):
y = x.clone()
# Local representations
x = self.conv1(x)
x = self.conv2(x)
# Global representations
_, _, h, w = x.shape
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
# Fusion
x = self.conv3(x)
x = torch.cat((x, y), 1)
x = self.conv4(x)
return x
class MobileViT(nn.Module):
"""MobileViT.
Paper: https://arxiv.org/abs/2110.02178
Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
"""
def __init__(
self,
image_size,
dims,
channels,
num_classes,
expansion=4,
kernel_size=3,
patch_size=(2, 2),
depths=(2, 4, 3)
):
super().__init__()
assert len(dims) == 3, 'dims must be a tuple of 3'
assert len(depths) == 3, 'depths must be a tuple of 3'
ih, iw = image_size
ph, pw = patch_size
assert ih % ph == 0 and iw % pw == 0
init_dim, *_, last_dim = channels
self.conv1 = conv_nxn_bn(3, init_dim, stride=2)
self.stem = nn.ModuleList([])
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.trunk = nn.ModuleList([])
self.trunk.append(nn.ModuleList([
MV2Block(channels[3], channels[4], 2, expansion),
MobileViTBlock(dims[0], depths[0], channels[5],
kernel_size, patch_size, int(dims[0] * 2))
]))
self.trunk.append(nn.ModuleList([
MV2Block(channels[5], channels[6], 2, expansion),
MobileViTBlock(dims[1], depths[1], channels[7],
kernel_size, patch_size, int(dims[1] * 4))
]))
self.trunk.append(nn.ModuleList([
MV2Block(channels[7], channels[8], 2, expansion),
MobileViTBlock(dims[2], depths[2], channels[9],
kernel_size, patch_size, int(dims[2] * 4))
]))
self.to_logits = nn.Sequential(
conv_1x1_bn(channels[-2], last_dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(channels[-1], num_classes, bias=False)
)
def forward(self, x):
x = self.conv1(x)
for conv in self.stem:
x = conv(x)
for conv, attn in self.trunk:
x = conv(x)
x = attn(x)
return self.to_logits(x)

vit_pytorch/nest.py

@@ -20,9 +20,9 @@ class LayerNorm(nn.Module):
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
@@ -131,10 +131,11 @@ class NesT(nn.Module):
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across hierarchy
hierarchies = list(reversed(range(num_hierarchies)))
mults = [2 ** i for i in hierarchies]
mults = [2 ** i for i in reversed(hierarchies)]
layer_heads = list(map(lambda t: t * heads, mults))
layer_dims = list(map(lambda t: t * dim, mults))
last_dim = layer_dims[-1]
layer_dims = [*layer_dims, layer_dims[-1]]
dim_pairs = zip(layer_dims[:-1], layer_dims[1:])
@@ -157,10 +158,11 @@ class NesT(nn.Module):
Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
]))
self.mlp_head = nn.Sequential(
LayerNorm(dim),
LayerNorm(last_dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(dim, num_classes)
nn.Linear(last_dim, num_classes)
)
def forward(self, img):

vit_pytorch/pit.py

@@ -129,14 +129,15 @@ class PiT(nn.Module):
mlp_dim,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.
emb_dropout = 0.,
channels = 3
):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing'
heads = cast_tuple(heads, len(depth))
patch_dim = 3 * patch_size ** 2
patch_dim = channels * patch_size ** 2
self.to_patch_embedding = nn.Sequential(
nn.Unfold(kernel_size = patch_size, stride = patch_size // 2),

vit_pytorch/recorder.py

@@ -55,5 +55,5 @@ class Recorder(nn.Module):
target_device = self.device if self.device is not None else img.device
recordings = tuple(map(lambda t: t.to(target_device), self.recordings))
attns = torch.stack(recordings, dim = 1)
attns = torch.stack(recordings, dim = 1) if len(recordings) > 0 else None
return pred, attns
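
A minimal sketch of the guarded usage this change enables (assumes the usual `Recorder` wrapper and import path) - `attns` is now `None` when the wrapped model recorded no attention layers

```python
import torch
from vit_pytorch import ViT
from vit_pytorch.recorder import Recorder

v = Recorder(ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
))

img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)

if attns is not None: # None if nothing was recorded
    print(attns.shape)
```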

302
vit_pytorch/scalable_vit.py Normal file

@@ -0,0 +1,302 @@
from functools import partial
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
# helper classes
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanLayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
class Downsample(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)
def forward(self, x):
return self.conv(x)
class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
def forward(self, x):
return self.proj(x) + x
# feedforward
class FeedForward(nn.Module):
def __init__(self, dim, expansion_factor = 4, dropout = 0.):
super().__init__()
inner_dim = dim * expansion_factor
self.net = nn.Sequential(
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# attention
class ScalableSelfAttention(nn.Module):
def __init__(
self,
dim,
heads = 8,
dim_key = 64,
dim_value = 64,
dropout = 0.,
reduction_factor = 1
):
super().__init__()
self.heads = heads
self.scale = dim_key ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(dim_value * heads, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
height, width, heads = *x.shape[-2:], self.heads
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
# split out heads
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
# similarity
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# attention
attn = self.attend(dots)
# aggregate values
out = torch.matmul(attn, v)
# merge back heads
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = height, y = width)
return self.to_out(out)
class InteractiveWindowedSelfAttention(nn.Module):
def __init__(
self,
dim,
window_size,
heads = 8,
dim_key = 64,
dim_value = 64,
dropout = 0.
):
super().__init__()
self.heads = heads
self.scale = dim_key ** -0.5
self.window_size = window_size
self.attend = nn.Softmax(dim = -1)
self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_v = nn.Conv2d(dim, dim_value * heads, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(dim_value * heads, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size
wsz = default(wsz, height) # take height as window size if not given
assert (height % wsz) == 0 and (width % wsz) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz})'
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
# get output of LIM
local_out = self.local_interactive_module(v)
# divide into window (and split out heads) for efficient self attention
q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz, w2 = wsz), (q, k, v))
# similarity
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# attention
attn = self.attend(dots)
# aggregate values
out = torch.matmul(attn, v)
# reshape the windows back to full feature map (and merge heads)
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
# add LIM output
out = out + local_out
return self.to_out(out)
class Transformer(nn.Module):
def __init__(
self,
dim,
depth,
heads = 8,
ff_expansion_factor = 4,
dropout = 0.,
ssa_dim_key = 64,
ssa_dim_value = 64,
ssa_reduction_factor = 1,
iwsa_dim_key = 64,
iwsa_dim_value = 64,
iwsa_window_size = 64,
norm_output = True
):
super().__init__()
self.layers = nn.ModuleList([])
for ind in range(depth):
is_first = ind == 0
self.layers.append(nn.ModuleList([
PreNorm(dim, ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout)),
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
PEG(dim) if is_first else None,
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
PreNorm(dim, InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout))
]))
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for ssa, ff1, peg, iwsa, ff2 in self.layers:
x = ssa(x) + x
x = ff1(x) + x
if exists(peg):
x = peg(x)
x = iwsa(x) + x
x = ff2(x) + x
return self.norm(x)
class ScalableViT(nn.Module):
def __init__(
self,
*,
num_classes,
dim,
depth,
heads,
reduction_factor,
ff_expansion_factor = 4,
iwsa_dim_key = 64,
iwsa_dim_value = 64,
window_size = 64,
ssa_dim_key = 64,
ssa_dim_value = 64,
channels = 3,
dropout = 0.
):
super().__init__()
self.to_patches = nn.Conv2d(channels, dim, 7, stride = 4, padding = 3)
assert isinstance(depth, tuple), 'depth needs to be a tuple of integers indicating the number of transformer blocks at that stage'
num_stages = len(depth)
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
hyperparams_per_stage = [
heads,
ssa_dim_key,
ssa_dim_value,
reduction_factor,
iwsa_dim_key,
iwsa_dim_value,
window_size,
]
hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))
self.layers = nn.ModuleList([])
for ind, (layer_dim, layer_depth, layer_heads, layer_ssa_dim_key, layer_ssa_dim_value, layer_ssa_reduction_factor, layer_iwsa_dim_key, layer_iwsa_dim_value, layer_window_size) in enumerate(zip(dims, depth, *hyperparams_per_stage)):
is_last = ind == (num_stages - 1)
self.layers.append(nn.ModuleList([
Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_expansion_factor = ff_expansion_factor, dropout = dropout, ssa_dim_key = layer_ssa_dim_key, ssa_dim_value = layer_ssa_dim_value, ssa_reduction_factor = layer_ssa_reduction_factor, iwsa_dim_key = layer_iwsa_dim_key, iwsa_dim_value = layer_iwsa_dim_value, iwsa_window_size = layer_window_size),
Downsample(layer_dim, layer_dim * 2) if not is_last else None
]))
self.mlp_head = nn.Sequential(
Reduce('b d h w -> b d', 'mean'),
nn.LayerNorm(dims[-1]),
nn.Linear(dims[-1], num_classes)
)
def forward(self, img):
x = self.to_patches(img)
for transformer, downsample in self.layers:
x = transformer(x)
if exists(downsample):
x = downsample(x)
return self.mlp_head(x)


@@ -38,9 +38,9 @@ class LayerNorm(nn.Module):
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):

142
vit_pytorch/vit_for_small_dataset.py Normal file

@@ -0,0 +1,142 @@
from math import sqrt
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class LSA(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.temperature.exp()
mask = torch.eye(dots.shape[-1], device = dots.device, dtype = torch.bool)
mask_value = -torch.finfo(dots.dtype).max
dots = dots.masked_fill(mask, mask_value)
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class SPT(nn.Module):
def __init__(self, *, dim, patch_size, channels = 3):
super().__init__()
patch_dim = patch_size * patch_size * 5 * channels
self.to_patch_tokens = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim)
)
def forward(self, x):
shifts = ((1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1))
shifted_x = list(map(lambda shift: F.pad(x, shift), shifts))
x_with_shifts = torch.cat((x, *shifted_x), dim = 1)
return self.to_patch_tokens(x_with_shifts)
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = SPT(dim = dim, patch_size = patch_size, channels = channels)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)

144
vit_pytorch/vit_with_patch_merger.py Normal file

@@ -0,0 +1,144 @@
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# patch merger class
class PatchMerger(nn.Module):
def __init__(self, dim, num_tokens_out):
super().__init__()
self.scale = dim ** -0.5
self.norm = nn.LayerNorm(dim)
self.queries = nn.Parameter(torch.randn(num_tokens_out, dim))
def forward(self, x):
x = self.norm(x)
sim = torch.matmul(self.queries, x.transpose(-1, -2)) * self.scale
attn = sim.softmax(dim = -1)
return torch.matmul(attn, x)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8):
super().__init__()
self.layers = nn.ModuleList([])
self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper
self.patch_merger = PatchMerger(dim = dim, num_tokens_out = patch_merge_num_tokens)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for index, (attn, ff) in enumerate(self.layers):
x = attn(x) + x
x = ff(x) + x
if index == self.patch_merge_layer_index:
x = self.patch_merger(x)
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, patch_merge_layer, patch_merge_num_tokens)
self.mlp_head = nn.Sequential(
Reduce('b n d -> b d', 'mean'),
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
x += self.pos_embedding[:, :n]
x = self.dropout(x)
x = self.transformer(x)
return self.mlp_head(x)