mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 16:12:29 +00:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
719048d1bd | ||
|
|
d27721a85a | ||
|
|
cb22cbbd19 | ||
|
|
6db20debb4 | ||
|
|
1bae5d3cc5 | ||
|
|
25b384297d | ||
|
|
64a07f50e6 | ||
|
|
126d204ff2 | ||
|
|
c1528acd46 | ||
|
|
1cc0f182a6 | ||
|
|
28eaba6115 | ||
|
|
0082301f9e | ||
|
|
91ed738731 | ||
|
|
1b58daa20a | ||
|
|
f2414b2c1b | ||
|
|
891b92eb74 | ||
|
|
70ba532599 | ||
|
|
e52ac41955 | ||
|
|
0891885485 | ||
|
|
976f489230 |
33
.github/workflows/python-test.yml
vendored
Normal file
33
.github/workflows/python-test.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pytest
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
python setup.py test
|
||||
1
MANIFEST.in
Normal file
1
MANIFEST.in
Normal file
@@ -0,0 +1 @@
|
||||
recursive-include tests *
|
||||
166
README.md
166
README.md
@@ -18,12 +18,15 @@
|
||||
- [Twins SVT](#twins-svt)
|
||||
- [CrossFormer](#crossformer)
|
||||
- [RegionViT](#regionvit)
|
||||
- [ScalableViT](#scalablevit)
|
||||
- [NesT](#nest)
|
||||
- [MobileViT](#mobilevit)
|
||||
- [Masked Autoencoder](#masked-autoencoder)
|
||||
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
|
||||
- [Masked Patch Prediction](#masked-patch-prediction)
|
||||
- [Adaptive Token Sampling](#adaptive-token-sampling)
|
||||
- [Patch Merger](#patch-merger)
|
||||
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
|
||||
- [Dino](#dino)
|
||||
- [Accessing Attention](#accessing-attention)
|
||||
- [Research Ideas](#research-ideas)
|
||||
@@ -523,6 +526,38 @@ img = torch.randn(1, 3, 224, 224)
|
||||
pred = model(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## ScalableViT
|
||||
|
||||
<img src="./images/scalable-vit-1.png" width="400px"></img>
|
||||
|
||||
<img src="./images/scalable-vit-2.png" width="400px"></img>
|
||||
|
||||
This Bytedance AI <a href="https://arxiv.org/abs/2203.10790">paper</a> proposes the Scalable Self Attention (SSA) and the Interactive Windowed Self Attention (IWSA) modules. The SSA alleviates the computation needed at earlier stages by reducing the key / value feature map by some factor (`reduction_factor`), while modulating the dimension of the queries and keys (`ssa_dim_key`). The IWSA performs self attention within local windows, similar to other vision transformer papers. However, they add a residual of the values, passed through a convolution of kernel size 3, which they named Local Interactive Module (LIM).
|
||||
|
||||
They make the claim in this paper that this scheme outperforms Swin Transformer, and also demonstrate competitive performance against Crossformer.
|
||||
|
||||
You can use it as follows (ex. ScalableViT-S)
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.scalable_vit import ScalableViT
|
||||
|
||||
model = ScalableViT(
|
||||
num_classes = 1000,
|
||||
dim = 64, # starting model dimension. at every stage, dimension is doubled
|
||||
heads = (2, 4, 8, 16), # number of attention heads at each stage
|
||||
depth = (2, 2, 20, 2), # number of transformer blocks at each stage
|
||||
ssa_dim_key = (40, 40, 40, 32), # the dimension of the attention keys (and queries) for SSA. in the paper, they represented this as a scale factor on the base dimension per key (ssa_dim_key / dim_key)
|
||||
reduction_factor = (8, 4, 2, 1), # downsampling of the key / values in SSA. in the paper, this was represented as (reduction_factor ** -2)
|
||||
window_size = (64, 32, None, None), # window size of the IWSA at each stage. None means no windowing needed
|
||||
dropout = 0.1, # attention and feedforward dropout
|
||||
).cuda()
|
||||
|
||||
img = torch.randn(1, 3, 256, 256).cuda()
|
||||
|
||||
preds = model(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## NesT
|
||||
|
||||
<img src="./images/nest.png" width="400px"></img>
|
||||
@@ -541,7 +576,7 @@ nest = NesT(
|
||||
dim = 96,
|
||||
heads = 3,
|
||||
num_hierarchies = 3, # number of hierarchies
|
||||
block_repeats = (8, 4, 1), # the number of transformer blocks at each heirarchy, starting from the bottom
|
||||
block_repeats = (2, 2, 8), # the number of transformer blocks at each heirarchy, starting from the bottom
|
||||
num_classes = 1000
|
||||
)
|
||||
|
||||
@@ -731,12 +766,104 @@ v = ViT(
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
preds = v(img) # (4, 1000)
|
||||
|
||||
# you can also get a list of the final sampled patch ids
|
||||
# a value of -1 denotes padding
|
||||
|
||||
preds, token_ids = v(img, return_sampled_token_ids = True) # (1, 1000), (1, <=8)
|
||||
preds, token_ids = v(img, return_sampled_token_ids = True) # (4, 1000), (4, <=8)
|
||||
```
|
||||
|
||||
## Patch Merger
|
||||
|
||||
|
||||
<img src="./images/patch_merger.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2202.12015">paper</a> proposes a simple module (Patch Merger) for reducing the number of tokens at any layer of a vision transformer without sacrificing performance.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit_with_patch_merger import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 12,
|
||||
heads = 8,
|
||||
patch_merge_layer = 6, # at which transformer layer to do patch merging
|
||||
patch_merge_num_tokens = 8, # the output number of tokens from the patch merge
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (4, 1000)
|
||||
```
|
||||
|
||||
One can also use the `PatchMerger` module by itself
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit_with_patch_merger import PatchMerger
|
||||
|
||||
merger = PatchMerger(
|
||||
dim = 1024,
|
||||
num_tokens_out = 8 # output number of tokens
|
||||
)
|
||||
|
||||
features = torch.randn(4, 256, 1024) # (batch, num tokens, dimension)
|
||||
|
||||
out = merger(features) # (4, 8, 1024)
|
||||
```
|
||||
|
||||
## Vision Transformer for Small Datasets
|
||||
|
||||
<img src="./images/vit_for_small_datasets.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2112.13492">paper</a> proposes a new image to patch function that incorporates shifts of the image, before normalizing and dividing the image into patches. I have found shifting to be extremely helpful in some other transformers work, so decided to include this for further explorations. It also includes the `LSA` with the learned temperature and masking out of a token's attention to itself.
|
||||
|
||||
You can use as follows:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit_for_small_dataset import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
You can also use the `SPT` from this paper as a standalone module
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit_for_small_dataset import SPT
|
||||
|
||||
spt = SPT(
|
||||
dim = 1024,
|
||||
patch_size = 16,
|
||||
channels = 3
|
||||
)
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
|
||||
tokens = spt(img) # (4, 256, 1024)
|
||||
```
|
||||
|
||||
## Dino
|
||||
@@ -1236,6 +1363,39 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{lee2021vision,
|
||||
title = {Vision Transformer for Small-Size Datasets},
|
||||
author = {Seung Hoon Lee and Seunghyun Lee and Byung Cheol Song},
|
||||
year = {2021},
|
||||
eprint = {2112.13492},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{renggli2022learning,
|
||||
title = {Learning to Merge Tokens in Vision Transformers},
|
||||
author = {Cedric Renggli and André Susano Pinto and Neil Houlsby and Basil Mustafa and Joan Puigcerver and Carlos Riquelme},
|
||||
year = {2022},
|
||||
eprint = {2202.12015},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{yang2022scalablevit,
|
||||
title = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer},
|
||||
author = {Rui Yang and Hailong Ma and Jie Wu and Yansong Tang and Xuefeng Xiao and Min Zheng and Xiu Li},
|
||||
year = {2022},
|
||||
eprint = {2203.10790},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
|
||||
BIN
images/patch_merger.png
Normal file
BIN
images/patch_merger.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 54 KiB |
BIN
images/scalable-vit-1.png
Normal file
BIN
images/scalable-vit-1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 79 KiB |
BIN
images/scalable-vit-2.png
Normal file
BIN
images/scalable-vit-2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 62 KiB |
BIN
images/vit_for_small_datasets.png
Normal file
BIN
images/vit_for_small_datasets.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 80 KiB |
10
setup.py
10
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.25.1',
|
||||
version = '0.28.1',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
@@ -15,10 +15,16 @@ setup(
|
||||
'image recognition'
|
||||
],
|
||||
install_requires=[
|
||||
'einops>=0.3',
|
||||
'einops>=0.4.1',
|
||||
'torch>=1.6',
|
||||
'torchvision'
|
||||
],
|
||||
setup_requires=[
|
||||
'pytest-runner',
|
||||
],
|
||||
tests_require=[
|
||||
'pytest'
|
||||
],
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
'Intended Audience :: Developers',
|
||||
|
||||
20
tests/test.py
Normal file
20
tests/test.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
from vit_pytorch import ViT
|
||||
|
||||
def test():
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
preds = v(img)
|
||||
assert preds.shape == (1, 1000), 'correct logits outputted'
|
||||
@@ -62,9 +62,9 @@ class LayerNorm(nn.Module):
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (std + self.eps) * self.g + self.b
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0.):
|
||||
return nn.Sequential(
|
||||
|
||||
@@ -30,9 +30,9 @@ class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (std + self.eps) * self.g + self.b
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
|
||||
@@ -3,12 +3,16 @@ from torch import nn
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
image_size_h, image_size_w = pair(image_size)
|
||||
assert image_size_h % patch_size == 0 and image_size_w % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
num_patches = (image_size_h // patch_size) * (image_size_w // patch_size)
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
|
||||
@@ -5,7 +5,14 @@ def exists(val):
|
||||
return val is not None
|
||||
|
||||
class Extractor(nn.Module):
|
||||
def __init__(self, vit, device = None):
|
||||
def __init__(
|
||||
self,
|
||||
vit,
|
||||
device = None,
|
||||
layer_name = 'transformer',
|
||||
layer_save_input = False,
|
||||
return_embeddings_only = False
|
||||
):
|
||||
super().__init__()
|
||||
self.vit = vit
|
||||
|
||||
@@ -16,11 +23,18 @@ class Extractor(nn.Module):
|
||||
self.ejected = False
|
||||
self.device = device
|
||||
|
||||
def _hook(self, _, input, output):
|
||||
self.latents = output.clone().detach()
|
||||
self.layer_name = layer_name
|
||||
self.layer_save_input = layer_save_input # whether to save input or output of layer
|
||||
self.return_embeddings_only = return_embeddings_only
|
||||
|
||||
def _hook(self, _, inputs, output):
|
||||
tensor_to_save = inputs if self.layer_save_input else output
|
||||
self.latents = tensor_to_save.clone().detach()
|
||||
|
||||
def _register_hook(self):
|
||||
handle = self.vit.transformer.register_forward_hook(self._hook)
|
||||
assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
|
||||
layer = getattr(self.vit, self.layer_name)
|
||||
handle = layer.register_forward_hook(self._hook)
|
||||
self.hooks.append(handle)
|
||||
self.hook_registered = True
|
||||
|
||||
@@ -35,7 +49,11 @@ class Extractor(nn.Module):
|
||||
del self.latents
|
||||
self.latents = None
|
||||
|
||||
def forward(self, img):
|
||||
def forward(
|
||||
self,
|
||||
img,
|
||||
return_embeddings_only = False
|
||||
):
|
||||
assert not self.ejected, 'extractor has been ejected, cannot be used anymore'
|
||||
self.clear()
|
||||
if not self.hook_registered:
|
||||
@@ -45,4 +63,8 @@ class Extractor(nn.Module):
|
||||
|
||||
target_device = self.device if exists(self.device) else img.device
|
||||
latents = self.latents.to(target_device)
|
||||
|
||||
if return_embeddings_only or self.return_embeddings_only:
|
||||
return latents
|
||||
|
||||
return pred, latents
|
||||
|
||||
@@ -71,6 +71,10 @@ class MAE(nn.Module):
|
||||
|
||||
decoder_tokens = self.enc_to_dec(encoded_tokens)
|
||||
|
||||
# reapply decoder position embedding to unmasked tokens
|
||||
|
||||
decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
|
||||
|
||||
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
|
||||
|
||||
mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
|
||||
|
||||
@@ -1,40 +1,27 @@
|
||||
"""
|
||||
An implementation of MobileViT Model as defined in:
|
||||
MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer
|
||||
Arxiv: https://arxiv.org/abs/2110.02178
|
||||
Origin Code: https://github.com/murufeng/awesome_lightweight_networks
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Reduce
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
def conv_bn_relu(inp, oup, kernel, stride=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
# helpers
|
||||
|
||||
def conv_1x1_bn(inp, oup):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.ReLU6(inplace=True)
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
@@ -44,10 +31,11 @@ class PreNorm(nn.Module):
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.ffn = nn.Sequential(
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -56,8 +44,7 @@ class FeedForward(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.ffn(x)
|
||||
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
|
||||
@@ -76,7 +63,8 @@ class Attention(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
q, k, v = map(lambda t: rearrange(
|
||||
t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
attn = self.attend(dots)
|
||||
@@ -84,15 +72,19 @@ class Attention(nn.Module):
|
||||
out = rearrange(out, 'b p h n d -> b p n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
"""Transformer block described in ViT.
|
||||
Paper: https://arxiv.org/abs/2010.11929
|
||||
Based on: https://github.com/lucidrains/vit-pytorch
|
||||
"""
|
||||
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
@@ -102,17 +94,24 @@ class Transformer(nn.Module):
|
||||
return x
|
||||
|
||||
class MV2Block(nn.Module):
|
||||
def __init__(self, inp, oup, stride=1, expand_ratio=4):
|
||||
super(MV2Block, self).__init__()
|
||||
"""MV2 block described in MobileNetV2.
|
||||
Paper: https://arxiv.org/pdf/1801.04381
|
||||
Based on: https://github.com/tonylins/pytorch-mobilenet-v2
|
||||
"""
|
||||
|
||||
def __init__(self, inp, oup, stride=1, expansion=4):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = round(inp * expand_ratio)
|
||||
self.identity = stride == 1 and inp == oup
|
||||
hidden_dim = int(inp * expansion)
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
if expand_ratio == 1:
|
||||
if expansion == 1:
|
||||
self.conv = nn.Sequential(
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
|
||||
1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
@@ -126,7 +125,8 @@ class MV2Block(nn.Module):
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
|
||||
1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
@@ -136,8 +136,7 @@ class MV2Block(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(x)
|
||||
|
||||
if self.identity:
|
||||
if self.use_res_connect:
|
||||
out = out + x
|
||||
return out
|
||||
|
||||
@@ -146,13 +145,13 @@ class MobileViTBlock(nn.Module):
|
||||
super().__init__()
|
||||
self.ph, self.pw = patch_size
|
||||
|
||||
self.conv1 = conv_bn_relu(channel, channel, kernel_size)
|
||||
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
|
||||
self.conv2 = conv_1x1_bn(channel, dim)
|
||||
|
||||
self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)
|
||||
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
|
||||
|
||||
self.conv3 = conv_1x1_bn(dim, channel)
|
||||
self.conv4 = conv_bn_relu(2 * channel, channel, kernel_size)
|
||||
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
y = x.clone()
|
||||
@@ -163,9 +162,11 @@ class MobileViTBlock(nn.Module):
|
||||
|
||||
# Global representations
|
||||
_, _, h, w = x.shape
|
||||
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
|
||||
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
|
||||
ph=self.ph, pw=self.pw)
|
||||
x = self.transformer(x)
|
||||
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)
|
||||
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
|
||||
h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
|
||||
|
||||
# Fusion
|
||||
x = self.conv3(x)
|
||||
@@ -173,18 +174,22 @@ class MobileViTBlock(nn.Module):
|
||||
x = self.conv4(x)
|
||||
return x
|
||||
|
||||
|
||||
class MobileViT(nn.Module):
|
||||
"""MobileViT.
|
||||
Paper: https://arxiv.org/abs/2110.02178
|
||||
Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
dims,
|
||||
channels,
|
||||
num_classes,
|
||||
expansion = 4,
|
||||
kernel_size = 3,
|
||||
patch_size = (2, 2),
|
||||
depths = (2, 4, 3)
|
||||
expansion=4,
|
||||
kernel_size=3,
|
||||
patch_size=(2, 2),
|
||||
depths=(2, 4, 3)
|
||||
):
|
||||
super().__init__()
|
||||
assert len(dims) == 3, 'dims must be a tuple of 3'
|
||||
@@ -196,28 +201,31 @@ class MobileViT(nn.Module):
|
||||
|
||||
init_dim, *_, last_dim = channels
|
||||
|
||||
self.conv1 = conv_bn_relu(3, init_dim, kernel=3, stride=2)
|
||||
self.conv1 = conv_nxn_bn(3, init_dim, stride=2)
|
||||
|
||||
self.stem = nn.ModuleList([])
|
||||
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
|
||||
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
|
||||
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
|
||||
|
||||
self.trunk = nn.ModuleList([])
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[3], channels[4], 2, expansion),
|
||||
MobileViTBlock(dims[0], depths[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))
|
||||
MobileViTBlock(dims[0], depths[0], channels[5],
|
||||
kernel_size, patch_size, int(dims[0] * 2))
|
||||
]))
|
||||
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[5], channels[6], 2, expansion),
|
||||
MobileViTBlock(dims[1], depths[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))
|
||||
MobileViTBlock(dims[1], depths[1], channels[7],
|
||||
kernel_size, patch_size, int(dims[1] * 4))
|
||||
]))
|
||||
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[7], channels[8], 2, expansion),
|
||||
MobileViTBlock(dims[2], depths[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))
|
||||
MobileViTBlock(dims[2], depths[2], channels[9],
|
||||
kernel_size, patch_size, int(dims[2] * 4))
|
||||
]))
|
||||
|
||||
self.to_logits = nn.Sequential(
|
||||
|
||||
@@ -20,9 +20,9 @@ class LayerNorm(nn.Module):
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (std + self.eps) * self.g + self.b
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
@@ -131,10 +131,11 @@ class NesT(nn.Module):
|
||||
|
||||
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy
|
||||
hierarchies = list(reversed(range(num_hierarchies)))
|
||||
mults = [2 ** i for i in hierarchies]
|
||||
mults = [2 ** i for i in reversed(hierarchies)]
|
||||
|
||||
layer_heads = list(map(lambda t: t * heads, mults))
|
||||
layer_dims = list(map(lambda t: t * dim, mults))
|
||||
last_dim = layer_dims[-1]
|
||||
|
||||
layer_dims = [*layer_dims, layer_dims[-1]]
|
||||
dim_pairs = zip(layer_dims[:-1], layer_dims[1:])
|
||||
@@ -157,10 +158,11 @@ class NesT(nn.Module):
|
||||
Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
LayerNorm(dim),
|
||||
LayerNorm(last_dim),
|
||||
Reduce('b c h w -> b c', 'mean'),
|
||||
nn.Linear(dim, num_classes)
|
||||
nn.Linear(last_dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
|
||||
@@ -55,5 +55,5 @@ class Recorder(nn.Module):
|
||||
target_device = self.device if self.device is not None else img.device
|
||||
recordings = tuple(map(lambda t: t.to(target_device), self.recordings))
|
||||
|
||||
attns = torch.stack(recordings, dim = 1)
|
||||
attns = torch.stack(recordings, dim = 1) if len(recordings) > 0 else None
|
||||
return pred, attns
|
||||
|
||||
302
vit_pytorch/scalable_vit.py
Normal file
302
vit_pytorch/scalable_vit.py
Normal file
@@ -0,0 +1,302 @@
|
||||
from functools import partial
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
# helper classes
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = ChanLayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x))
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class PEG(nn.Module):
|
||||
def __init__(self, dim, kernel_size = 3):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(x) + x
|
||||
|
||||
# feedforward
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, expansion_factor = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim * expansion_factor
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, inner_dim, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# attention
|
||||
|
||||
class ScalableSelfAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
heads = 8,
|
||||
dim_key = 32,
|
||||
dim_value = 32,
|
||||
dropout = 0.,
|
||||
reduction_factor = 1
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_key ** -0.5
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
|
||||
self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
|
||||
self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(dim_value * heads, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
height, width, heads = *x.shape[-2:], self.heads
|
||||
|
||||
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
|
||||
|
||||
# split out heads
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
|
||||
|
||||
# similarity
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
# attention
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
|
||||
# merge back heads
|
||||
|
||||
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = height, y = width)
|
||||
return self.to_out(out)
|
||||
|
||||
class InteractiveWindowedSelfAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
window_size,
|
||||
heads = 8,
|
||||
dim_key = 32,
|
||||
dim_value = 32,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_key ** -0.5
|
||||
self.window_size = window_size
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)
|
||||
|
||||
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
|
||||
self.to_k = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
|
||||
self.to_v = nn.Conv2d(dim, dim_value * heads, 1, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(dim_value * heads, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size
|
||||
|
||||
wsz = default(wsz, height) # take height as window size if not given
|
||||
assert (height % wsz) == 0 and (width % wsz) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz})'
|
||||
|
||||
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
|
||||
|
||||
# get output of LIM
|
||||
|
||||
local_out = self.local_interactive_module(v)
|
||||
|
||||
# divide into window (and split out heads) for efficient self attention
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz, w2 = wsz), (q, k, v))
|
||||
|
||||
# similarity
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
# attention
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
|
||||
# reshape the windows back to full feature map (and merge heads)
|
||||
|
||||
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
|
||||
|
||||
# add LIM output
|
||||
|
||||
out = out + local_out
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
heads = 8,
|
||||
ff_expansion_factor = 4,
|
||||
dropout = 0.,
|
||||
ssa_dim_key = 32,
|
||||
ssa_dim_value = 32,
|
||||
ssa_reduction_factor = 1,
|
||||
iwsa_dim_key = 32,
|
||||
iwsa_dim_value = 32,
|
||||
iwsa_window_size = None,
|
||||
norm_output = True
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for ind in range(depth):
|
||||
is_first = ind == 0
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
|
||||
PEG(dim) if is_first else None,
|
||||
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
|
||||
PreNorm(dim, InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout))
|
||||
]))
|
||||
|
||||
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
for ssa, ff1, peg, iwsa, ff2 in self.layers:
|
||||
x = ssa(x) + x
|
||||
x = ff1(x) + x
|
||||
|
||||
if exists(peg):
|
||||
x = peg(x)
|
||||
|
||||
x = iwsa(x) + x
|
||||
x = ff2(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ScalableViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
reduction_factor,
|
||||
window_size = None,
|
||||
iwsa_dim_key = 32,
|
||||
iwsa_dim_value = 32,
|
||||
ssa_dim_key = 32,
|
||||
ssa_dim_value = 32,
|
||||
ff_expansion_factor = 4,
|
||||
channels = 3,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.to_patches = nn.Conv2d(channels, dim, 7, stride = 4, padding = 3)
|
||||
|
||||
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
|
||||
|
||||
num_stages = len(depth)
|
||||
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
|
||||
|
||||
hyperparams_per_stage = [
|
||||
heads,
|
||||
ssa_dim_key,
|
||||
ssa_dim_value,
|
||||
reduction_factor,
|
||||
iwsa_dim_key,
|
||||
iwsa_dim_value,
|
||||
window_size,
|
||||
]
|
||||
|
||||
hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
|
||||
assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
for ind, (layer_dim, layer_depth, layer_heads, layer_ssa_dim_key, layer_ssa_dim_value, layer_ssa_reduction_factor, layer_iwsa_dim_key, layer_iwsa_dim_value, layer_window_size) in enumerate(zip(dims, depth, *hyperparams_per_stage)):
|
||||
is_last = ind == (num_stages - 1)
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_expansion_factor = ff_expansion_factor, dropout = dropout, ssa_dim_key = layer_ssa_dim_key, ssa_dim_value = layer_ssa_dim_value, ssa_reduction_factor = layer_ssa_reduction_factor, iwsa_dim_key = layer_iwsa_dim_key, iwsa_dim_value = layer_iwsa_dim_value, iwsa_window_size = layer_window_size),
|
||||
Downsample(layer_dim, layer_dim * 2) if not is_last else None
|
||||
]))
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
Reduce('b d h w -> b d', 'mean'),
|
||||
nn.LayerNorm(dims[-1]),
|
||||
nn.Linear(dims[-1], num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patches(img)
|
||||
|
||||
for transformer, downsample in self.layers:
|
||||
x = transformer(x)
|
||||
|
||||
if exists(downsample):
|
||||
x = downsample(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||
@@ -38,9 +38,9 @@ class LayerNorm(nn.Module):
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (std + self.eps) * self.g + self.b
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
|
||||
142
vit_pytorch/vit_for_small_dataset.py
Normal file
142
vit_pytorch/vit_for_small_dataset.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from math import sqrt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class LSA(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.temperature.exp()
|
||||
|
||||
mask = torch.eye(dots.shape[-1], device = dots.device, dtype = torch.bool)
|
||||
mask_value = -torch.finfo(dots.dtype).max
|
||||
dots = dots.masked_fill(mask, mask_value)
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class SPT(nn.Module):
|
||||
def __init__(self, *, dim, patch_size, channels = 3):
|
||||
super().__init__()
|
||||
patch_dim = patch_size * patch_size * 5 * channels
|
||||
|
||||
self.to_patch_tokens = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
shifts = ((1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1))
|
||||
shifted_x = list(map(lambda shift: F.pad(x, shift), shifts))
|
||||
x_with_shifts = torch.cat((x, *shifted_x), dim = 1)
|
||||
return self.to_patch_tokens(x_with_shifts)
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = SPT(dim = dim, patch_size = patch_size, channels = channels)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
144
vit_pytorch/vit_with_patch_merger.py
Normal file
144
vit_pytorch/vit_with_patch_merger.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val ,d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# patch merger class
|
||||
|
||||
class PatchMerger(nn.Module):
|
||||
def __init__(self, dim, num_tokens_out):
|
||||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.queries = nn.Parameter(torch.randn(num_tokens_out, dim))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
sim = torch.matmul(self.queries, x.transpose(-1, -2)) * self.scale
|
||||
attn = sim.softmax(dim = -1)
|
||||
return torch.matmul(attn, x)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper
|
||||
self.patch_merger = PatchMerger(dim = dim, num_tokens_out = patch_merge_num_tokens)
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for index, (attn, ff) in enumerate(self.layers):
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
if index == self.patch_merge_layer_index:
|
||||
x = self.patch_merger(x)
|
||||
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, patch_merge_layer, patch_merge_num_tokens)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
Reduce('b n d -> b d', 'mean'),
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
x += self.pos_embedding[:, :n]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||
Reference in New Issue
Block a user