Compare commits

...

31 Commits
1.2.0 ... 1.6.4

Author SHA1 Message Date
lucidrains
d446a41243 share an idea that should be tried if it has not been 2023-11-14 16:55:36 -08:00
lucidrains
0ad09c4cbc allow channels to be customizable for cvt 2023-10-25 14:47:58 -07:00
Phil Wang
92b69321f4 1.6.2 2023-10-24 12:47:38 -07:00
Artem Lukin
fb4ac25174 Fix typo in LayerNorm (#285)
Co-authored-by: Artem Lukin <artyom.lukin98@gmail.com>
2023-10-24 12:47:21 -07:00
lucidrains
53fe345e85 no longer needed with einops 0.7 2023-10-19 18:16:46 -07:00
Phil Wang
efb94608ea readme 2023-10-19 09:38:35 -07:00
lucidrains
51310d1d07 add xcit diagram 2023-10-13 09:18:12 -07:00
Phil Wang
1616288e30 add xcit (#284)
* add xcit

* use Rearrange layers

* give cross correlation transformer a final norm at end

* document
2023-10-13 09:15:13 -07:00
Jason Chou
9e1e824385 Update README.md (#283)
`patch_size` is size of patches, not number of patches
2023-10-09 11:33:56 -07:00
lucidrains
bbb24e34d4 give a learned bias to and from registers for maxvit + register token variant 2023-10-06 10:40:26 -07:00
lucidrains
df8733d86e improvise a max vit with register tokens 2023-10-06 10:27:36 -07:00
lucidrains
680d446e46 document in readme later 2023-10-03 09:26:02 -07:00
lucidrains
3fdb8dd352 fix pypi 2023-10-01 08:14:20 -07:00
lucidrains
a36546df23 add simple vit with register tokens example, cite 2023-10-01 08:11:40 -07:00
lucidrains
d830b05f06 address https://github.com/lucidrains/vit-pytorch/issues/279 2023-09-10 09:32:57 -07:00
Phil Wang
8208c859a5 just remove PreNorm wrapper from all ViTs, as it is unlikely to change at this point 2023-08-14 09:48:55 -07:00
Phil Wang
4264efd906 1.4.2 2023-08-14 07:59:35 -07:00
Phil Wang
b194359301 add a simple vit with qknorm, since authors seem to be promoting the technique on twitter 2023-08-14 07:58:45 -07:00
lucidrains
950c901b80 fix linear head in simple vit, thanks to @atkos 2023-08-10 14:36:21 -07:00
Phil Wang
3e5d1be6f0 address https://github.com/lucidrains/vit-pytorch/pull/274 2023-08-09 07:53:38 -07:00
Phil Wang
6e2393de95 wrap up NaViT 2023-07-25 10:38:55 -07:00
Phil Wang
32974c33df one can pass a callback to token_dropout_prob for NaViT that takes in height and width and calculate appropriate dropout rate 2023-07-24 14:52:40 -07:00
Phil Wang
17675e0de4 add constant token dropout for NaViT 2023-07-24 14:14:36 -07:00
Phil Wang
598cffab53 release NaViT 2023-07-24 13:55:54 -07:00
Phil Wang
23820bc54a begin work on NaViT (#273)
finish core idea of NaViT
2023-07-24 13:54:02 -07:00
Phil Wang
e9ca1f4d57 1.2.5 2023-07-24 06:43:24 -07:00
roydenwa
d4daf7bd0f Support SimpleViT as encoder in MAE (#272)
support simplevit in mae
2023-07-24 06:43:01 -07:00
Phil Wang
9e3fec2398 fix mpp 2023-06-28 08:02:43 -07:00
Phil Wang
ce4bcd08fb address https://github.com/lucidrains/vit-pytorch/issues/266 2023-05-20 08:24:49 -07:00
Phil Wang
ad4ca19775 enforce latest einops 2023-05-08 09:34:14 -07:00
Phil Wang
e1b08c15b9 fix tests 2023-03-19 10:52:47 -07:00
40 changed files with 1789 additions and 317 deletions

View File

@@ -1,11 +1,16 @@
# This workflows will upload a Python Package using Twine when a release is created
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
name: Upload Python Package
on:
release:
types: [created]
types: [published]
jobs:
deploy:
@@ -21,11 +26,11 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
python setup.py sdist bdist_wheel
twine upload dist/*
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: [3.8, 3.9]
steps:
- uses: actions/checkout@v2
@@ -27,6 +27,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install pytest
python -m pip install wheel
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |

120
README.md
View File

@@ -7,6 +7,7 @@
- [Usage](#usage)
- [Parameters](#parameters)
- [Simple ViT](#simple-vit)
- [NaViT](#navit)
- [Distillation](#distillation)
- [Deep ViT](#deep-vit)
- [CaiT](#cait)
@@ -24,6 +25,7 @@
- [MaxViT](#maxvit)
- [NesT](#nest)
- [MobileViT](#mobilevit)
- [XCiT](#xcit)
- [Masked Autoencoder](#masked-autoencoder)
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
@@ -91,7 +93,7 @@ preds = v(img) # (1, 1000)
- `image_size`: int.
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
- `patch_size`: int.
Number of patches. `image_size` must be divisible by `patch_size`.
Size of patches. `image_size` must be divisible by `patch_size`.
The number of patches is: ` n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
- `num_classes`: int.
Number of classes to classify.
@@ -139,6 +141,63 @@ img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
```
## NaViT
<img src="./images/navit.png" width="450px"></img>
<a href="https://arxiv.org/abs/2307.06304">This paper</a> proposes to leverage the flexibility of attention and masking for variable lengthed sequences to train images of multiple resolution, packed into a single batch. They demonstrate much faster training and improved accuracies, with the only cost being extra complexity in the architecture and dataloading. They use factorized 2d positional encodings, token dropping, as well as query-key normalization.
You can use it as follows
```python
import torch
from vit_pytorch.na_vit import NaViT
v = NaViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1,
token_dropout_prob = 0.1 # token dropout of 10% (keep 90% of tokens)
)
# 5 images of different resolutions - List[List[Tensor]]
# for now, you'll have to correctly place images in same batch element as to not exceed maximum allowed sequence length for self-attention w/ masking
images = [
[torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
[torch.randn(3, 128, 256), torch.randn(3, 256, 128)],
[torch.randn(3, 64, 256)]
]
preds = v(images) # (5, 1000) - 5, because 5 images of different resolution above
```
Or if you would rather that the framework auto group the images into variable lengthed sequences that do not exceed a certain max length
```python
images = [
torch.randn(3, 256, 256),
torch.randn(3, 128, 128),
torch.randn(3, 128, 256),
torch.randn(3, 256, 128),
torch.randn(3, 64, 256)
]
preds = v(
images,
group_images = True,
group_max_seq_len = 64
) # (5, 1000)
```
## Distillation
<img src="./images/distill.png" width="300px"></img>
@@ -714,6 +773,38 @@ img = torch.randn(1, 3, 256, 256)
pred = mbvit_xs(img) # (1, 1000)
```
## XCiT
<img src="./images/xcit.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2106.09681">paper</a> introduces the cross covariance attention (abbreviated XCA). One can think of it as doing attention across the features dimension rather than the spatial one (another perspective would be a dynamic 1x1 convolution, the kernel being attention map defined by spatial correlations).
Technically, this amounts to simply transposing the query, key, values before executing cosine similarity attention with learned temperature.
```python
import torch
from vit_pytorch.xcit import XCiT
v = XCiT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 12, # depth of xcit transformer
cls_depth = 2, # depth of cross attention of CLS tokens to patch, attention pool at end
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1,
layer_dropout = 0.05, # randomly dropout 5% of the layers
local_patch_kernel_size = 3 # kernel size of the local patch interaction module (depthwise convs)
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
```
## Simple Masked Image Modeling
<img src="./images/simmim.png" width="400px"/>
@@ -1934,6 +2025,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@inproceedings{Dehghani2023PatchNP,
title = {Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution},
author = {Mostafa Dehghani and Basil Mustafa and Josip Djolonga and Jonathan Heek and Matthias Minderer and Mathilde Caron and Andreas Steiner and Joan Puigcerver and Robert Geirhos and Ibrahim M. Alabdulmohsin and Avital Oliver and Piotr Padlewski and Alexey A. Gritsenko and Mario Luvci'c and Neil Houlsby},
year = {2023}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
@@ -1954,4 +2053,23 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@inproceedings{Darcet2023VisionTN,
title = {Vision Transformers Need Registers},
author = {Timoth'ee Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:263134283}
}
```
```bibtex
@inproceedings{ElNouby2021XCiTCI,
title = {XCiT: Cross-Covariance Image Transformers},
author = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
booktitle = {Neural Information Processing Systems},
year = {2021},
url = {https://api.semanticscholar.org/CorpusID:235458262}
}
```
*I visualise a time when we will be to robots what dogs are to humans, and Im rooting for the machines.* — Claude Shannon

BIN
images/navit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 133 KiB

BIN
images/xcit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 814 KiB

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.2.0',
version = '1.6.4',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
@@ -16,7 +16,7 @@ setup(
'image recognition'
],
install_requires=[
'einops>=0.6.0',
'einops>=0.7.0',
'torch>=1.10',
'torchvision'
],

View File

@@ -110,18 +110,11 @@ class AdaptiveTokenSampling(nn.Module):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -138,6 +131,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -154,6 +148,7 @@ class Attention(nn.Module):
def forward(self, x, *, mask):
num_tokens = x.shape[1]
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -189,8 +184,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):

View File

@@ -44,18 +44,11 @@ class LayerScale(nn.Module):
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -72,6 +65,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -89,6 +83,7 @@ class Attention(nn.Module):
def forward(self, x, context = None):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
context = x if not exists(context) else torch.cat((x, context), dim = 1)
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
@@ -115,8 +110,8 @@ class Transformer(nn.Module):
for ind in range(depth):
self.layers.append(nn.ModuleList([
LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = ind + 1),
LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = ind + 1)
]))
def forward(self, x, context = None):
layers = dropout_layers(self.layers, dropout = self.layer_dropout)

View File

@@ -13,22 +13,13 @@ def exists(val):
def default(val, d):
return val if exists(val) else d
# pre-layernorm
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
# feedforward
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -47,6 +38,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -60,6 +52,7 @@ class Attention(nn.Module):
def forward(self, x, context = None, kv_include_self = False):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
context = default(context, x)
if kv_include_self:
@@ -86,8 +79,8 @@ class Transformer(nn.Module):
self.norm = nn.LayerNorm(dim)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
@@ -121,8 +114,8 @@ class CrossTransformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
ProjectInOut(sm_dim, lg_dim, PreNorm(lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout))),
ProjectInOut(lg_dim, sm_dim, PreNorm(sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout)))
ProjectInOut(sm_dim, lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout)),
ProjectInOut(lg_dim, sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout))
]))
def forward(self, sm_tokens, lg_tokens):

View File

@@ -34,19 +34,11 @@ class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
LayerNorm(dim),
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -75,6 +67,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -89,6 +82,8 @@ class Attention(nn.Module):
def forward(self, x):
shape = x.shape
b, n, _, y, h = *shape, self.heads
x = self.norm(x)
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
@@ -107,8 +102,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_mult, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
@@ -145,12 +140,13 @@ class CvT(nn.Module):
s3_heads = 6,
s3_depth = 10,
s3_mlp_mult = 4,
dropout = 0.
dropout = 0.,
channels = 3
):
super().__init__()
kwargs = dict(locals())
dim = 3
dim = channels
layers = []
for prefix in ('s1', 's2', 's3'):

View File

@@ -5,25 +5,11 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -40,6 +26,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.dropout = nn.Dropout(dropout)
@@ -59,6 +46,8 @@ class Attention(nn.Module):
def forward(self, x):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
@@ -86,13 +75,13 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x)
x = ff(x)
x = attn(x) + x
x = ff(x) + x
return x
class DeepViT(nn.Module):

View File

@@ -26,16 +26,6 @@ class ExcludeCLS(nn.Module):
x = self.fn(x, **kwargs)
return torch.cat((cls_token, x), dim = 1)
# prenorm
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
# feed forward related classes
class DepthWiseConv2d(nn.Module):
@@ -52,6 +42,7 @@ class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Conv2d(dim, hidden_dim, 1),
nn.Hardswish(),
DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
@@ -77,6 +68,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -88,6 +80,8 @@ class Attention(nn.Module):
def forward(self, x):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
@@ -106,8 +100,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
ExcludeCLS(Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))))
Residual(Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
ExcludeCLS(Residual(FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -49,7 +49,10 @@ class MAE(nn.Module):
# patch to encoder tokens and add positions
tokens = self.patch_to_emb(patches)
tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]
if self.encoder.pool == "cls":
tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
elif self.encoder.pool == "mean":
tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)
# calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked

View File

@@ -19,20 +19,20 @@ def cast_tuple(val, length = 1):
# helper classes
class PreNormResidual(nn.Module):
class Residual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x)) + x
return self.fn(x) + x
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -132,6 +132,7 @@ class Attention(nn.Module):
self.heads = dim // dim_head
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.attend = nn.Sequential(
@@ -160,6 +161,8 @@ class Attention(nn.Module):
def forward(self, x):
batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
x = self.norm(x)
# flatten
x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
@@ -170,7 +173,7 @@ class Attention(nn.Module):
# split heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d ) -> b h n d', h = h), (q, k, v))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# scale
@@ -259,13 +262,13 @@ class MaxViT(nn.Module):
shrinkage_rate = mbconv_shrinkage_rate
),
Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
)

View File

@@ -0,0 +1,340 @@
from functools import partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Sequential
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def pack_one(x, pattern):
return pack([x], pattern)
def unpack_one(x, ps, pattern):
return unpack(x, ps, pattern)[0]
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
# helper classes
def FeedForward(dim, mult = 4, dropout = 0.):
inner_dim = int(dim * mult)
return Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
# MBConv
class SqueezeExcitation(Module):
def __init__(self, dim, shrinkage_rate = 0.25):
super().__init__()
hidden_dim = int(dim * shrinkage_rate)
self.gate = Sequential(
Reduce('b c h w -> b c', 'mean'),
nn.Linear(dim, hidden_dim, bias = False),
nn.SiLU(),
nn.Linear(hidden_dim, dim, bias = False),
nn.Sigmoid(),
Rearrange('b c -> b c 1 1')
)
def forward(self, x):
return x * self.gate(x)
class MBConvResidual(Module):
def __init__(self, fn, dropout = 0.):
super().__init__()
self.fn = fn
self.dropsample = Dropsample(dropout)
def forward(self, x):
out = self.fn(x)
out = self.dropsample(out)
return out + x
class Dropsample(Module):
def __init__(self, prob = 0):
super().__init__()
self.prob = prob
def forward(self, x):
device = x.device
if self.prob == 0. or (not self.training):
return x
keep_mask = torch.FloatTensor((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
return x * keep_mask / (1 - self.prob)
def MBConv(
dim_in,
dim_out,
*,
downsample,
expansion_rate = 4,
shrinkage_rate = 0.25,
dropout = 0.
):
hidden_dim = int(expansion_rate * dim_out)
stride = 2 if downsample else 1
net = Sequential(
nn.Conv2d(dim_in, hidden_dim, 1),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
nn.Conv2d(hidden_dim, dim_out, 1),
nn.BatchNorm2d(dim_out)
)
if dim_in == dim_out and not downsample:
net = MBConvResidual(net, dropout = dropout)
return net
# attention related classes
class Attention(Module):
def __init__(
self,
dim,
dim_head = 32,
dropout = 0.,
window_size = 7,
num_registers = 1
):
super().__init__()
assert num_registers > 0
assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
self.heads = dim // dim_head
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
)
self.to_out = nn.Sequential(
nn.Linear(dim, dim, bias = False),
nn.Dropout(dropout)
)
# relative positional bias
num_rel_pos_bias = (2 * window_size - 1) ** 2
self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
def forward(self, x):
device, h, bias_indices = x.device, self.heads, self.rel_pos_indices
x = self.norm(x)
# project for queries, keys, values
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
# split heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# scale
q = q * self.scale
# sim
sim = einsum('b h i d, b h j d -> b h i j', q, k)
# add positional bias
bias = self.rel_pos_bias(bias_indices)
sim = sim + rearrange(bias, 'i j h -> h i j')
# attention
attn = self.attend(sim)
# aggregate
out = einsum('b h i j, b h j d -> b h i d', attn, v)
# combine heads out
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class MaxViT(Module):
def __init__(
self,
*,
num_classes,
dim,
depth,
dim_head = 32,
dim_conv_stem = None,
window_size = 7,
mbconv_expansion_rate = 4,
mbconv_shrinkage_rate = 0.25,
dropout = 0.1,
channels = 3,
num_register_tokens = 4
):
super().__init__()
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
assert num_register_tokens > 0
# convolutional stem
dim_conv_stem = default(dim_conv_stem, dim)
self.conv_stem = Sequential(
nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
)
# variables
num_stages = len(depth)
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
dims = (dim_conv_stem, *dims)
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
self.layers = nn.ModuleList([])
# window size
self.window_size = window_size
self.register_tokens = nn.ParameterList([])
# iterate through stages
for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
for stage_ind in range(layer_depth):
is_first = stage_ind == 0
stage_dim_in = layer_dim_in if is_first else layer_dim
conv = MBConv(
stage_dim_in,
layer_dim,
downsample = is_first,
expansion_rate = mbconv_expansion_rate,
shrinkage_rate = mbconv_shrinkage_rate
)
block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
block_ff = FeedForward(dim = layer_dim, dropout = dropout)
grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
grid_ff = FeedForward(dim = layer_dim, dropout = dropout)
register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))
self.layers.append(ModuleList([
conv,
ModuleList([block_attn, block_ff]),
ModuleList([grid_attn, grid_ff])
]))
self.register_tokens.append(register_tokens)
# mlp head out
self.mlp_head = nn.Sequential(
Reduce('b d h w -> b d', 'mean'),
nn.LayerNorm(dims[-1]),
nn.Linear(dims[-1], num_classes)
)
def forward(self, x):
b, w = x.shape[0], self.window_size
x = self.conv_stem(x)
for (conv, (block_attn, block_ff), (grid_attn, grid_ff)), register_tokens in zip(self.layers, self.register_tokens):
x = conv(x)
# block-like attention
x = rearrange(x, 'b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w)
# prepare register tokens
r = repeat(register_tokens, 'n d -> b x y n d', b = b, x = x.shape[1],y = x.shape[2])
r, register_batch_ps = pack_one(r, '* n d')
x, window_ps = pack_one(x, 'b x y * d')
x, batch_ps = pack_one(x, '* n d')
x, register_ps = pack([r, x], 'b * d')
x = block_attn(x) + x
x = block_ff(x) + x
r, x = unpack(x, register_ps, 'b * d')
x = unpack_one(x, batch_ps, '* n d')
x = unpack_one(x, window_ps, 'b x y * d')
x = rearrange(x, 'b x y w1 w2 d -> b d (x w1) (y w2)')
r = unpack_one(r, register_batch_ps, '* n d')
# grid-like attention
x = rearrange(x, 'b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w)
# prepare register tokens
r = reduce(r, 'b x y n d -> b n d', 'mean')
r = repeat(r, 'b n d -> b x y n d', x = x.shape[1], y = x.shape[2])
r, register_batch_ps = pack_one(r, '* n d')
x, window_ps = pack_one(x, 'b x y * d')
x, batch_ps = pack_one(x, '* n d')
x, register_ps = pack([r, x], 'b * d')
x = grid_attn(x) + x
r, x = unpack(x, register_ps, 'b * d')
x = grid_ff(x) + x
x = unpack_one(x, batch_ps, '* n d')
x = unpack_one(x, window_ps, 'b x y * d')
x = rearrange(x, 'b x y w1 w2 d -> b d (w1 x) (w2 y)')
return self.mlp_head(x)

View File

@@ -22,20 +22,11 @@ def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.SiLU(),
nn.Dropout(dropout),
@@ -53,6 +44,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
@@ -64,9 +56,10 @@ class Attention(nn.Module):
)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(
t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
@@ -88,8 +81,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
Attention(dim, heads, dim_head, dropout),
FeedForward(dim, mlp_dim, dropout)
]))
def forward(self, x):
@@ -167,11 +160,9 @@ class MobileViTBlock(nn.Module):
# Global representations
_, _, h, w = x.shape
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
# Fusion
x = self.conv3(x)

View File

@@ -96,6 +96,9 @@ class MPP(nn.Module):
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
max_pixel_val, mean, std)
# extract patching function
self.patch_to_emb = nn.Sequential(transformer.to_patch_embedding[1:])
# output transformation
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
@@ -151,7 +154,7 @@ class MPP(nn.Module):
masked_input[bool_mask_replace] = self.mask_token
# linear embedding of patches
masked_input = transformer.to_patch_embedding[-1](masked_input)
masked_input = self.patch_to_emb(masked_input)
# add cls token to input sequence
b, n, _ = masked_input.shape

389
vit_pytorch/na_vit.py Normal file
View File

@@ -0,0 +1,389 @@
from functools import partial
from typing import List, Union
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def always(val):
return lambda *args: val
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def divisible_by(numer, denom):
return (numer % denom) == 0
# auto grouping images
def group_images_by_max_seq_len(
images: List[Tensor],
patch_size: int,
calc_token_dropout = None,
max_seq_len = 2048
) -> List[List[Tensor]]:
calc_token_dropout = default(calc_token_dropout, always(0.))
groups = []
group = []
seq_len = 0
if isinstance(calc_token_dropout, (float, int)):
calc_token_dropout = always(calc_token_dropout)
for image in images:
assert isinstance(image, Tensor)
image_dims = image.shape[-2:]
ph, pw = map(lambda t: t // patch_size, image_dims)
image_seq_len = (ph * pw)
image_seq_len = int(image_seq_len * (1 - calc_token_dropout(*image_dims)))
assert image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'
if (seq_len + image_seq_len) > max_seq_len:
groups.append(group)
group = []
seq_len = 0
group.append(image)
seq_len += image_seq_len
if len(group) > 0:
groups.append(group)
return groups
# normalization
# they use layernorm without bias, something that pytorch does not offer
class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer('beta', torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
# they use a query-key normalization that is equivalent to rms norm (no mean-centering, learned gamma), from vit 22B paper
class RMSNorm(nn.Module):
def __init__(self, heads, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(heads, 1, dim))
def forward(self, x):
normed = F.normalize(x, dim = -1)
return normed * self.scale * self.gamma
# feedforward
def FeedForward(dim, hidden_dim, dropout = 0.):
return nn.Sequential(
LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.norm = LayerNorm(dim)
self.q_norm = RMSNorm(heads, dim_head)
self.k_norm = RMSNorm(heads, dim_head)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
nn.Dropout(dropout)
)
def forward(
self,
x,
context = None,
mask = None,
attn_mask = None
):
x = self.norm(x)
kv_input = default(context, x)
qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
q = self.q_norm(q)
k = self.k_norm(k)
dots = torch.matmul(q, k.transpose(-1, -2))
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
if exists(attn_mask):
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
self.norm = LayerNorm(dim)
def forward(
self,
x,
mask = None,
attn_mask = None
):
for attn, ff in self.layers:
x = attn(x, mask = mask, attn_mask = attn_mask) + x
x = ff(x) + x
return self.norm(x)
class NaViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = None):
super().__init__()
image_height, image_width = pair(image_size)
# what percent of tokens to dropout
# if int or float given, then assume constant dropout prob
# otherwise accept a callback that in turn calculates dropout prob from height and width
self.calc_token_dropout = None
if callable(token_dropout_prob):
self.calc_token_dropout = token_dropout_prob
elif isinstance(token_dropout_prob, (float, int)):
assert 0. < token_dropout_prob < 1.
token_dropout_prob = float(token_dropout_prob)
self.calc_token_dropout = lambda height, width: token_dropout_prob
# calculate patching related stuff
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
patch_dim = channels * (patch_size ** 2)
self.channels = channels
self.patch_size = patch_size
self.to_patch_embedding = nn.Sequential(
LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
LayerNorm(dim),
)
self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
# final attention pooling queries
self.attn_pool_queries = nn.Parameter(torch.randn(dim))
self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)
# output to logits
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
LayerNorm(dim),
nn.Linear(dim, num_classes, bias = False)
)
@property
def device(self):
return next(self.parameters()).device
def forward(
self,
batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
group_images = False,
group_max_seq_len = 2048
):
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
arange = partial(torch.arange, device = device)
pad_sequence = partial(orig_pad_sequence, batch_first = True)
# auto pack if specified
if group_images:
batched_images = group_images_by_max_seq_len(
batched_images,
patch_size = self.patch_size,
calc_token_dropout = self.calc_token_dropout,
max_seq_len = group_max_seq_len
)
# process images into variable lengthed sequences with attention mask
num_images = []
batched_sequences = []
batched_positions = []
batched_image_ids = []
for images in batched_images:
num_images.append(len(images))
sequences = []
positions = []
image_ids = torch.empty((0,), device = device, dtype = torch.long)
for image_id, image in enumerate(images):
assert image.ndim ==3 and image.shape[0] == c
image_dims = image.shape[-2:]
assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
ph, pw = map(lambda dim: dim // p, image_dims)
pos = torch.stack(torch.meshgrid((
arange(ph),
arange(pw)
), indexing = 'ij'), dim = -1)
pos = rearrange(pos, 'h w c -> (h w) c')
seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
seq_len = seq.shape[-2]
if has_token_dropout:
token_dropout = self.calc_token_dropout(*image_dims)
num_keep = max(1, int(seq_len * (1 - token_dropout)))
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
seq = seq[keep_indices]
pos = pos[keep_indices]
image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
sequences.append(seq)
positions.append(pos)
batched_image_ids.append(image_ids)
batched_sequences.append(torch.cat(sequences, dim = 0))
batched_positions.append(torch.cat(positions, dim = 0))
# derive key padding mask
lengths = torch.tensor([seq.shape[-2] for seq in batched_sequences], device = device, dtype = torch.long)
max_length = arange(lengths.amax().item())
key_pad_mask = rearrange(lengths, 'b -> b 1') <= rearrange(max_length, 'n -> 1 n')
# derive attention mask, and combine with key padding mask from above
batched_image_ids = pad_sequence(batched_image_ids)
attn_mask = rearrange(batched_image_ids, 'b i -> b 1 i 1') == rearrange(batched_image_ids, 'b j -> b 1 1 j')
attn_mask = attn_mask & rearrange(key_pad_mask, 'b j -> b 1 1 j')
# combine patched images as well as the patched width / height positions for 2d positional embedding
patches = pad_sequence(batched_sequences)
patch_positions = pad_sequence(batched_positions)
# need to know how many images for final attention pooling
num_images = torch.tensor(num_images, device = device, dtype = torch.long)
# to patches
x = self.to_patch_embedding(patches)
# factorized 2d absolute positional embedding
h_indices, w_indices = patch_positions.unbind(dim = -1)
h_pos = self.pos_embed_height[h_indices]
w_pos = self.pos_embed_width[w_indices]
x = x + h_pos + w_pos
# embed dropout
x = self.dropout(x)
# attention
x = self.transformer(x, attn_mask = attn_mask)
# do attention pooling at the end
max_queries = num_images.amax().item()
queries = repeat(self.attn_pool_queries, 'd -> b n d', n = max_queries, b = x.shape[0])
# attention pool mask
image_id_arange = arange(max_queries)
attn_pool_mask = rearrange(image_id_arange, 'i -> i 1') == rearrange(batched_image_ids, 'b j -> b 1 j')
attn_pool_mask = attn_pool_mask & rearrange(key_pad_mask, 'b j -> b 1 j')
attn_pool_mask = rearrange(attn_pool_mask, 'b i j -> b 1 i j')
# attention pool
x = self.attn_pool(queries, context = x, attn_mask = attn_pool_mask) + queries
x = rearrange(x, 'b n d -> (b n) d')
# each batch element may not have same amount of images
is_images = image_id_arange < rearrange(num_images, 'b -> b 1')
is_images = rearrange(is_images, 'b n -> (b n)')
x = x[is_images]
# project out to logits
x = self.to_latent(x)
return self.mlp_head(x)

View File

@@ -24,19 +24,11 @@ class LayerNorm(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mlp_mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
LayerNorm(dim),
nn.Conv2d(dim, dim * mlp_mult, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -54,6 +46,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
@@ -66,6 +59,8 @@ class Attention(nn.Module):
def forward(self, x):
b, c, h, w, heads = *x.shape, self.heads
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)
@@ -93,8 +88,8 @@ class Transformer(nn.Module):
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
Attention(dim, heads = heads, dropout = dropout),
FeedForward(dim, mlp_mult, dropout = dropout)
]))
def forward(self, x):
*_, h, w = x.shape

View File

@@ -19,18 +19,11 @@ class Parallel(nn.Module):
def forward(self, x):
return sum([fn(x) for fn in self.fns])
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -49,6 +42,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -60,6 +54,7 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -77,8 +72,8 @@ class Transformer(nn.Module):
super().__init__()
self.layers = nn.ModuleList([])
attn_block = lambda: PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
ff_block = lambda: PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
attn_block = lambda: Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)
ff_block = lambda: FeedForward(dim, mlp_dim, dropout = dropout)
for _ in range(depth):
self.layers.append(nn.ModuleList([

View File

@@ -17,18 +17,11 @@ def conv_output_size(image_size, kernel_size, stride, padding = 0):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -47,6 +40,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -58,6 +52,8 @@ class Attention(nn.Module):
def forward(self, x):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
@@ -76,8 +72,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -55,14 +55,6 @@ class DepthWiseConv2d(nn.Module):
# helper classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class SpatialConv(nn.Module):
def __init__(self, dim_in, dim_out, kernel, bias = False):
super().__init__()
@@ -86,6 +78,7 @@ class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0., use_glu = True):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim * 2 if use_glu else hidden_dim),
GEGLU() if use_glu else nn.GELU(),
nn.Dropout(dropout),
@@ -103,6 +96,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -121,6 +115,9 @@ class Attention(nn.Module):
b, n, _, h = *x.shape, self.heads
to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {}
x = self.norm(x)
q = self.to_q(x, **to_q_kwargs)
qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
@@ -162,8 +159,8 @@ class Transformer(nn.Module):
self.pos_emb = AxialRotaryEmbedding(dim_head, max_freq = image_size)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv),
FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu)
]))
def forward(self, x, fmap_dims):
pos_emb = self.pos_emb(x[:, 1:])

View File

@@ -33,15 +33,6 @@ class ChanLayerNorm(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanLayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
class Downsample(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
@@ -65,6 +56,7 @@ class FeedForward(nn.Module):
super().__init__()
inner_dim = dim * expansion_factor
self.net = nn.Sequential(
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -92,6 +84,7 @@ class ScalableSelfAttention(nn.Module):
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.norm = ChanLayerNorm(dim)
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)
@@ -104,6 +97,8 @@ class ScalableSelfAttention(nn.Module):
def forward(self, x):
height, width, heads = *x.shape[-2:], self.heads
x = self.norm(x)
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
# split out heads
@@ -145,6 +140,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.norm = ChanLayerNorm(dim)
self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
@@ -159,6 +155,8 @@ class InteractiveWindowedSelfAttention(nn.Module):
def forward(self, x):
height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size
x = self.norm(x)
wsz_h, wsz_w = default(wsz, height), default(wsz, width)
assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'
@@ -217,11 +215,11 @@ class Transformer(nn.Module):
is_first = ind == 0
self.layers.append(nn.ModuleList([
PreNorm(dim, ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout)),
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout),
FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
PEG(dim) if is_first else None,
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
PreNorm(dim, InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout))
FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout)
]))
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

View File

@@ -25,15 +25,6 @@ class ChanLayerNorm(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanLayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
class OverlappingPatchEmbed(nn.Module):
def __init__(self, dim_in, dim_out, stride = 2):
super().__init__()
@@ -59,6 +50,7 @@ class FeedForward(nn.Module):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -85,6 +77,8 @@ class DSSA(nn.Module):
self.window_size = window_size
inner_dim = dim_head * heads
self.norm = ChanLayerNorm(dim)
self.attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
@@ -138,6 +132,8 @@ class DSSA(nn.Module):
assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
num_windows = (height // wsz) * (width // wsz)
x = self.norm(x)
# fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
@@ -225,8 +221,8 @@ class Transformer(nn.Module):
for ind in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = dropout)),
DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mult = ff_mult, dropout = dropout),
]))
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

View File

@@ -9,17 +9,15 @@ from einops.layers.torch import Rearrange
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
omega = 1. / (temperature ** omega)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# classes
@@ -66,6 +64,7 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
@@ -76,7 +75,7 @@ class Transformer(nn.Module):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
return self.norm(x)
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
@@ -86,30 +85,33 @@ class SimpleViT(nn.Module):
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
*_, h, w, dtype = *img.shape, img.dtype
device = img.device
x = self.to_patch_embedding(img)
pe = posemb_sincos_2d(x)
x = rearrange(x, 'b ... d -> b (...) d') + pe
x += self.pos_embedding.to(device, dtype=x.dtype)
x = self.transformer(x)
x = x.mean(dim = 1)

View File

@@ -62,6 +62,7 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
@@ -72,7 +73,7 @@ class Transformer(nn.Module):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
return self.norm(x)
class SimpleViT(nn.Module):
def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
@@ -93,10 +94,7 @@ class SimpleViT(nn.Module):
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, series):
*_, n, dtype = *series.shape, series.dtype

View File

@@ -77,6 +77,7 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
@@ -87,7 +88,7 @@ class Transformer(nn.Module):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
return self.norm(x)
class SimpleViT(nn.Module):
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
@@ -111,10 +112,7 @@ class SimpleViT(nn.Module):
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, video):
*_, h, w, dtype = *video.shape, video.dtype

View File

@@ -0,0 +1,162 @@
import torch
from torch.fft import fft
from torch import nn
from einops import rearrange, reduce, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, freq_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
freq_patch_height, freq_patch_width = pair(freq_patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert image_height % freq_patch_height == 0 and image_width % freq_patch_width == 0, 'Image dimensions must be divisible by the freq patch size.'
patch_dim = channels * patch_height * patch_width
freq_patch_dim = channels * 2 * freq_patch_height * freq_patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.to_freq_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) ri -> b (h w) (p1 p2 ri c)", p1 = freq_patch_height, p2 = freq_patch_width),
nn.LayerNorm(freq_patch_dim),
nn.Linear(freq_patch_dim, dim),
nn.LayerNorm(dim)
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.freq_pos_embedding = posemb_sincos_2d(
h = image_height // freq_patch_height,
w = image_width // freq_patch_width,
dim = dim
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
device, dtype = img.device, img.dtype
x = self.to_patch_embedding(img)
freqs = torch.view_as_real(fft(img))
f = self.to_freq_embedding(freqs)
x += self.pos_embedding.to(device, dtype = dtype)
f += self.freq_pos_embedding.to(device, dtype = dtype)
x, ps = pack((f, x), 'b * d')
x = self.transformer(x)
_, x = unpack(x, ps, 'b * d')
x = reduce(x, 'b n d -> b d', 'mean')
x = self.to_latent(x)
return self.linear_head(x)
if __name__ == '__main__':
vit = SimpleViT(
num_classes = 1000,
image_size = 256,
patch_size = 8,
freq_patch_size = 8,
dim = 1024,
depth = 1,
heads = 8,
mlp_dim = 2048,
)
images = torch.randn(8, 3, 256, 256)
logits = vit(images)

View File

@@ -87,6 +87,7 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
@@ -97,7 +98,7 @@ class Transformer(nn.Module):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
return self.norm(x)
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, patch_dropout = 0.5):
@@ -122,10 +123,7 @@ class SimpleViT(nn.Module):
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
*_, h, w, dtype = *img.shape, img.dtype

View File

@@ -0,0 +1,141 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# they use a query-key normalization that is equivalent to rms norm (no mean-centering, learned gamma), from vit 22B paper
# in latest tweet, seem to claim more stable training at higher learning rates
# unsure if this has taken off within Brain, or it has some hidden drawback
class RMSNorm(nn.Module):
def __init__(self, heads, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(heads, 1, dim) / self.scale)
def forward(self, x):
normed = F.normalize(x, dim = -1)
return normed * self.scale * self.gamma
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.q_norm = RMSNorm(heads, dim_head)
self.k_norm = RMSNorm(heads, dim_head)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
q = self.q_norm(q)
k = self.k_norm(k)
dots = torch.matmul(q, k.transpose(-1, -2))
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.LayerNorm(dim)
def forward(self, img):
device = img.device
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype=x.dtype)
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)

View File

@@ -0,0 +1,134 @@
"""
Vision Transformers Need Registers
https://arxiv.org/abs/2309.16588
"""
import torch
from torch import nn
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_register_tokens = 4, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
batch, device = img.shape[0], img.device
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype=x.dtype)
r = repeat(self.register_tokens, 'n d -> b n d', b = batch)
x, ps = pack([x, r], 'b * d')
x = self.transformer(x)
x, _ = unpack(x, ps, 'b * d')
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)

View File

@@ -42,20 +42,11 @@ class LayerNorm(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
LayerNorm(dim),
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -99,6 +90,7 @@ class LocalAttention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = LayerNorm(dim)
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)
@@ -108,6 +100,8 @@ class LocalAttention(nn.Module):
)
def forward(self, fmap):
fmap = self.norm(fmap)
shape, p = fmap.shape, self.patch_size
b, n, x, y, h = *shape, self.heads
x, y = map(lambda t: t // p, (x, y))
@@ -132,6 +126,8 @@ class GlobalAttention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = LayerNorm(dim)
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)
@@ -143,6 +139,8 @@ class GlobalAttention(nn.Module):
)
def forward(self, x):
x = self.norm(x)
shape = x.shape
b, n, _, y, h = *shape, self.heads
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
@@ -164,10 +162,10 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size))) if has_local else nn.Identity(),
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))) if has_local else nn.Identity(),
Residual(PreNorm(dim, GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k))),
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout)))
Residual(LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size)) if has_local else nn.Identity(),
Residual(FeedForward(dim, mlp_mult, dropout = dropout)) if has_local else nn.Identity(),
Residual(GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k)),
Residual(FeedForward(dim, mlp_mult, dropout = dropout))
]))
def forward(self, x):
for local_attn, ff1, global_attn, ff2 in self.layers:

View File

@@ -11,24 +11,18 @@ def pair(t):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
@@ -41,6 +35,8 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -52,6 +48,8 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -67,17 +65,20 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
return self.norm(x)
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
@@ -107,10 +108,7 @@ class ViT(nn.Module):
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, img):
x = self.to_patch_embedding(img)

View File

@@ -6,18 +6,11 @@ from einops.layers.torch import Rearrange
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -36,6 +29,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -47,6 +41,7 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -65,8 +60,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -11,18 +11,11 @@ def pair(t):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -41,6 +34,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -52,6 +46,7 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -70,8 +65,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -13,18 +13,11 @@ def pair(t):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -41,6 +34,7 @@ class LSA(nn.Module):
self.heads = heads
self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -52,6 +46,7 @@ class LSA(nn.Module):
)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -74,8 +69,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -30,18 +30,11 @@ class PatchDropout(nn.Module):
return x[batch_indices, patch_indices_keep]
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -60,6 +53,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -71,6 +65,7 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -89,8 +84,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -32,18 +32,11 @@ class PatchMerger(nn.Module):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -62,6 +55,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -73,6 +67,7 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -88,6 +83,7 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper
@@ -95,8 +91,8 @@ class Transformer(nn.Module):
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for index, (attn, ff) in enumerate(self.layers):
@@ -106,7 +102,7 @@ class Transformer(nn.Module):
if index == self.patch_merge_layer_index:
x = self.patch_merger(x)
return x
return self.norm(x)
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
@@ -133,7 +129,6 @@ class ViT(nn.Module):
self.mlp_head = nn.Sequential(
Reduce('b n d -> b d', 'mean'),
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)

View File

@@ -14,18 +14,11 @@ def pair(t):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -44,6 +37,7 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -55,6 +49,7 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -70,17 +65,18 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
return self.norm(x)
class ViT(nn.Module):
def __init__(
@@ -137,16 +133,13 @@ class ViT(nn.Module):
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, video):
x = self.to_patch_embedding(video)
b, f, n, _ = x.shape
x = x + self.pos_embedding
x = x + self.pos_embedding[:, :f, :n]
if exists(self.spatial_cls_token):
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)

283
vit_pytorch/xcit.py Normal file
View File

@@ -0,0 +1,283 @@
from random import randrange
import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
def dropout_layers(layers, dropout):
if dropout == 0:
return layers
num_layers = len(layers)
to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout
# make sure at least one layer makes it
if all(to_drop):
rand_index = randrange(num_layers)
to_drop[rand_index] = False
layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
return layers
# classes
class LayerScale(Module):
def __init__(self, dim, fn, depth):
super().__init__()
if depth <= 18:
init_eps = 0.1
elif 18 > depth <= 24:
init_eps = 1e-5
else:
init_eps = 1e-6
self.fn = fn
self.scale = nn.Parameter(torch.full((dim,), init_eps))
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, context = None):
h = self.heads
x = self.norm(x)
context = x if not exists(context) else torch.cat((x, context), dim = 1)
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(sim)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class XCAttention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
h = self.heads
x, ps = pack_one(x, 'b * d')
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h d n', h = h), (q, k, v))
q, k = map(l2norm, (q, k))
sim = einsum('b h i n, b h j n -> b h i j', q, k) * self.temperature.exp()
attn = self.attend(sim)
attn = self.dropout(attn)
out = einsum('b h i j, b h j n -> b h i n', attn, v)
out = rearrange(out, 'b h d n -> b n (h d)')
out = unpack_one(out, ps, 'b * d')
return self.to_out(out)
class LocalPatchInteraction(Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
assert (kernel_size % 2) == 1
padding = kernel_size // 2
self.net = nn.Sequential(
nn.LayerNorm(dim),
Rearrange('b h w c -> b c h w'),
nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
nn.BatchNorm2d(dim),
nn.GELU(),
nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
Rearrange('b c h w -> b h w c'),
)
def forward(self, x):
return self.net(x)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
super().__init__()
self.layers = ModuleList([])
self.layer_dropout = layer_dropout
for ind in range(depth):
layer = ind + 1
self.layers.append(ModuleList([
LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
]))
def forward(self, x, context = None):
layers = dropout_layers(self.layers, dropout = self.layer_dropout)
for attn, ff in layers:
x = attn(x, context = context) + x
x = ff(x) + x
return x
class XCATransformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size = 3, dropout = 0., layer_dropout = 0.):
super().__init__()
self.layers = ModuleList([])
self.layer_dropout = layer_dropout
for ind in range(depth):
layer = ind + 1
self.layers.append(ModuleList([
LayerScale(dim, XCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
LayerScale(dim, LocalPatchInteraction(dim, local_patch_kernel_size), depth = layer),
LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
]))
def forward(self, x):
layers = dropout_layers(self.layers, dropout = self.layer_dropout)
for cross_covariance_attn, local_patch_interaction, ff in layers:
x = cross_covariance_attn(x) + x
x = local_patch_interaction(x) + x
x = ff(x) + x
return x
class XCiT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
cls_depth,
heads,
mlp_dim,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
local_patch_kernel_size = 3,
layer_dropout = 0.
):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = 3 * patch_size ** 2
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
self.cls_token = nn.Parameter(torch.randn(dim))
self.dropout = nn.Dropout(emb_dropout)
self.xcit_transformer = XCATransformer(dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size, dropout, layer_dropout)
self.final_norm = nn.LayerNorm(dim)
self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
x, ps = pack_one(x, 'b * d')
b, n, _ = x.shape
x += self.pos_embedding[:, :n]
x = unpack_one(x, ps, 'b * d')
x = self.dropout(x)
x = self.xcit_transformer(x)
x = self.final_norm(x)
cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)
x = rearrange(x, 'b ... d -> b (...) d')
cls_tokens = self.cls_transformer(cls_tokens, context = x)
return self.mlp_head(cls_tokens[:, 0])