Compare commits

...

130 Commits

Author SHA1 Message Date
Phil Wang
4c37586510 add simple vit, from https://arxiv.org/abs/2205.01580 2022-05-03 19:33:48 -07:00
Phil Wang
4ef72fc4dc add EsViT, by popular request, an alternative to Dino that is compatible with efficient ViTs with accounting for regional self-supervised loss 2022-05-03 10:29:29 -07:00
Zhengzhong Tu
c2aab05ebf fix bibtex typo (#212) 2022-04-06 22:15:05 -07:00
Phil Wang
81661e3966 fix mbconv residual block 2022-04-06 16:43:06 -07:00
Phil Wang
13f8e123bb fix maxvit - need feedforwards after attention 2022-04-06 16:34:40 -07:00
Phil Wang
2d4089c88e link to maxvit in readme 2022-04-06 16:24:12 -07:00
Phil Wang
c7bb5fc43f maxvit intent to build (#211)
complete hybrid mbconv + block / grid efficient self attention MaxViT
2022-04-06 16:12:17 -07:00
Phil Wang
946b19be64 sponsor button 2022-04-06 14:12:11 -07:00
Phil Wang
d93cd84ccd let windowed tokens exchange information across heads a la talking heads prior to pointwise attention in sep-vit 2022-03-31 15:22:24 -07:00
Phil Wang
5d4c798949 cleanup sepvit 2022-03-31 14:35:11 -07:00
Phil Wang
d65a742efe intent to build (#210)
complete SepViT, from bytedance AI labs
2022-03-31 14:30:23 -07:00
Phil Wang
8c54e01492 do not layernorm on last transformer block for scalable vit, as there is already one in mlp head 2022-03-31 13:25:21 -07:00
Phil Wang
df656fe7c7 complete learnable memory ViT, for efficient fine-tuning and potentially plays into continual learning 2022-03-31 09:51:12 -07:00
Phil Wang
4e6a42a0ca correct need for post-attention dropout 2022-03-30 10:50:57 -07:00
Phil Wang
6d7298d8ad link to tensorflow2 translation by @taki0112 2022-03-28 09:05:34 -07:00
Phil Wang
9cd56ff29b CCT allow for rectangular images 2022-03-26 14:02:49 -07:00
Phil Wang
2aae406ce8 add proposed parallel vit from facebook ai for exploration purposes 2022-03-23 10:42:35 -07:00
Phil Wang
c2b2db2a54 fix window size of none for scalable vit for rectangular images 2022-03-22 17:37:59 -07:00
Phil Wang
719048d1bd some better defaults for scalable vit 2022-03-22 17:19:58 -07:00
Phil Wang
d27721a85a add scalable vit, from bytedance AI 2022-03-22 17:02:47 -07:00
Phil Wang
cb22cbbd19 update to einops 0.4, which is torchscript jit friendly 2022-03-22 13:58:00 -07:00
Phil Wang
6db20debb4 add patch merger 2022-03-01 16:50:17 -08:00
Phil Wang
1bae5d3cc5 allow for rectangular images for efficient adapter 2022-01-31 08:55:31 -08:00
Phil Wang
25b384297d return None from extractor if no attention layers 2022-01-28 17:49:58 -08:00
Phil Wang
64a07f50e6 epsilon should be inside square root 2022-01-24 17:24:41 -08:00
Phil Wang
126d204ff2 fix block repeats in readme example for Nest 2022-01-22 21:32:53 -08:00
Phil Wang
c1528acd46 fix feature maps in Nest, thanks to @MarkYangjiayi 2022-01-22 13:17:30 -08:00
Phil Wang
1cc0f182a6 decoder positional embedding needs to be reapplied https://twitter.com/giffmana/status/1479195631587631104 2022-01-06 13:14:41 -08:00
Phil Wang
28eaba6115 0.26.2 2022-01-03 12:56:34 -08:00
Phil Wang
0082301f9e build @jrounds suggestion 2022-01-03 12:56:25 -08:00
Phil Wang
91ed738731 0.26.1 2021-12-30 19:31:26 -08:00
Phil Wang
1b58daa20a Merge pull request #186 from chinhsuanwu/mobilevit
Update MobileViT
2021-12-30 19:31:01 -08:00
chinhsuanwu
f2414b2c1b Update MobileViT 2021-12-30 05:52:23 +08:00
Phil Wang
891b92eb74 readme 2021-12-28 16:00:00 -08:00
Phil Wang
70ba532599 add ViT for small datasets https://arxiv.org/abs/2112.13492 2021-12-28 10:58:21 -08:00
Phil Wang
e52ac41955 allow extractor to only return embeddings, to ready for vision transformers to be used in x-clip 2021-12-25 12:31:21 -08:00
Phil Wang
0891885485 include tests in package for conda 2021-12-22 12:44:29 -08:00
Phil Wang
976f489230 add some tests 2021-12-22 09:13:31 -08:00
Phil Wang
2c368d1d4e add extractor wrapper 2021-12-21 11:11:39 -08:00
Phil Wang
b983bbee39 release MobileViT, from @murufeng 2021-12-21 10:22:59 -08:00
Phil Wang
86a7302ba6 Merge pull request #181 from murufeng/main
Add MobileViT
2021-12-21 09:51:56 -08:00
murufeng
89d3a04b3f Add files via upload 2021-12-21 20:48:34 +08:00
murufeng
e7075c64aa Update README.md 2021-12-21 20:44:30 +08:00
murufeng
5ea1559e4c Add files via upload 2021-12-21 20:41:01 +08:00
Phil Wang
f4b0b14094 add ATS to table of contents 2021-12-03 20:07:18 -08:00
Phil Wang
365b4d931e add adaptive token sampling paper 2021-12-03 19:52:40 -08:00
Phil Wang
79c864d796 link to community youtuber 2021-11-24 08:13:52 -08:00
Phil Wang
b45c1356a1 cleanup 2021-11-22 22:53:02 -08:00
Phil Wang
ff44d97cb0 make initial channels customizable for PiT 2021-11-22 18:08:49 -08:00
Phil Wang
d35345df6a remove wip 2021-11-22 17:43:04 -08:00
Phil Wang
b69b5af34f dynamic positional bias for crossformer the more efficient way as described in appendix of paper 2021-11-22 17:39:36 -08:00
Phil Wang
36e32b70fb complete and release crossformer 2021-11-22 17:10:53 -08:00
Phil Wang
768e47441e crossformer without dynamic position bias 2021-11-22 16:21:55 -08:00
Phil Wang
de0b8ba189 additional diagram 2021-11-22 14:05:39 -08:00
Phil Wang
6665fc6cd1 cleanup region vit 2021-11-22 12:42:24 -08:00
Phil Wang
5b2382f9f0 intent to add 2021-11-22 12:00:03 -08:00
Phil Wang
9f8c60651d clearer mae 2021-11-22 10:19:48 -08:00
Phil Wang
5ae555750f add SimMIM 2021-11-21 15:50:19 -08:00
Phil Wang
c5a461661c Merge pull request #170 from ankandrew/patch-1
add Table of Contents
2021-11-17 16:55:09 -08:00
ankandrew
e212918e2d add Table of Contents 2021-11-17 21:21:19 -03:00
Phil Wang
dc57c75478 cleanup 2021-11-14 12:24:48 -08:00
Phil Wang
99c44cf5f6 readme 2021-11-14 11:49:12 -08:00
Phil Wang
5b16e8f809 readme 2021-11-12 20:19:38 -08:00
Phil Wang
e8f6d72033 release masked autoencoder 2021-11-12 20:08:48 -08:00
Phil Wang
cb1729af28 more efficient feedforward for regionvit 2021-11-07 17:18:59 -08:00
Phil Wang
9e50b2a41e readme 2021-11-07 09:59:49 -08:00
Phil Wang
06d375351e add RegionViT paper 2021-11-07 09:47:28 -08:00
Phil Wang
f196d1ec5b move freqs in RvT to linspace 2021-10-05 09:23:44 -07:00
Phil Wang
529044c9b3 Merge pull request #153 from developer0hye/fix-example
fix transforms for val an test process in example code
2021-09-02 06:57:16 -07:00
yhkwon-DT01
c30655f3bc fix transforms for val an test process 2021-09-02 17:30:18 +09:00
Phil Wang
d2d6de01d3 0.20.7 2021-08-30 08:14:43 -07:00
Phil Wang
b9eadaef60 Merge pull request #151 from developer0hye/patch-1
Cleanup Attention Class & matmul based implementation for TensorRT conversion
2021-08-30 08:14:11 -07:00
Yonghye Kwon
24ac8350bf remove unused package 2021-08-30 18:25:03 +09:00
Yonghye Kwon
ca3cef9de0 Cleanup Attention Class 2021-08-30 18:05:16 +09:00
Phil Wang
6e1be11517 0.20.6 2021-08-21 09:03:54 -07:00
Phil Wang
73ed562ce4 Merge pull request #147 from developer0hye/patch-4
Make T2T process any scale image
2021-08-21 09:03:42 -07:00
Phil Wang
ff863175a6 Merge pull request #146 from developer0hye/patch-1
Make Pit process image with width and height less than the image_size
2021-08-21 09:03:31 -07:00
Yonghye Kwon
ca0bdca192 Make model process any scale image
Related to #145
2021-08-21 22:35:26 +09:00
Yonghye Kwon
1c70271778 Support image with width and height less than the image_size
Related to #145
2021-08-21 22:25:46 +09:00
Phil Wang
d7d3febfe3 Merge pull request #144 from developer0hye/patch-1
Remove unused package
2021-08-20 10:14:02 -07:00
Yonghye Kwon
946815164a Remove unused package 2021-08-20 13:44:57 +09:00
Phil Wang
aeed3381c1 use hardswish for levit 2021-08-19 08:22:55 -07:00
Phil Wang
3f754956fb remove last transformer layer in t2t 2021-08-14 08:06:23 -07:00
Phil Wang
918869571c fix hard distillation, thanks to @CiaoHe 2021-08-12 08:40:57 -07:00
Phil Wang
e5324242be fix wrong norm in nest 2021-08-05 12:55:48 -07:00
Phil Wang
22da26fa4b fix recorder in data parallel situation 2021-07-08 10:15:07 -07:00
Phil Wang
a6c085a2df 0.20.0 for cct 2021-07-02 15:48:48 -07:00
Phil Wang
121353c604 Merge pull request #128 from stevenwalton/main
Adding Compact Convolutional Transformers (CCT)
2021-07-02 15:48:32 -07:00
alih
2ece3333da Minor changes 2021-07-01 17:51:35 -07:00
Ali Hassani
a73030c9aa Update README.md 2021-07-01 16:41:27 -07:00
Steven Walton
780f91a220 Tested and changed README format 2021-07-01 16:26:41 -07:00
Steven Walton
88451068e8 Adding CCT
Adding Compact Convolutional Transformers (CCT) from Escaping the Big Data
Paradigm with Compact Transformers by Hassani et. al.
https://arxiv.org/abs/2104.05704
2021-07-01 16:22:33 -07:00
Phil Wang
64a2ef6462 fix mpp 2021-06-16 16:46:32 -07:00
Phil Wang
53884f583f 0.19.5 2021-06-16 14:24:46 -07:00
Phil Wang
e616b5dcbc Merge pull request #101 from zankner/mpp-fix
Mpp fix
2021-06-16 14:24:26 -07:00
Phil Wang
60ad4e266e layernorm on channel dimension == instancenorm2d with affine set to true 2021-06-03 16:41:45 -07:00
Phil Wang
a254a0258a fix typo 2021-06-01 07:33:00 -07:00
Phil Wang
26df10c0b7 fix max pool in nest 2021-05-28 11:06:02 -07:00
Phil Wang
17cb8976df make nest resilient to dimension that are not divisible by number of heads 2021-05-27 22:41:07 -07:00
Phil Wang
daf3abbeb5 add NesT 2021-05-27 22:02:17 -07:00
Phil Wang
b483b16833 0.18.4 2021-05-18 14:40:33 -07:00
Phil Wang
c457573808 Merge pull request #118 from loctruong96/main
update  mpp.py to work on GPU
2021-05-18 14:40:17 -07:00
Loc Truong
e75b6d0251 Update mpp.py
fix issue with GPU device mismatch
2021-05-16 20:07:49 -07:00
Phil Wang
679e5be3e7 apply scale to 2d rel pos bias in levit 2021-05-10 11:37:23 -07:00
Phil Wang
7333979e6b add link to official repo for levit 2021-05-06 13:12:30 -07:00
Phil Wang
74b402377b add image 2021-05-02 15:40:53 -07:00
Phil Wang
41d2d460d0 link to yannic 2021-05-02 14:51:55 -07:00
Phil Wang
04f86dee3c implement SOTA new self-supervised learning technique from facebook for vision transformers, Dino 2021-05-02 14:00:36 -07:00
Phil Wang
6549522629 be able to accept non-square patches, thanks to @FilipAndersson245 2021-05-01 20:04:41 -07:00
Phil Wang
6a80a4ef89 update readme 2021-05-01 11:51:35 -07:00
Phil Wang
9f05587a7d 0.17.2 2021-04-30 06:44:59 -07:00
Phil Wang
65bb350e85 0.17.2 2021-04-30 06:44:54 -07:00
Phil Wang
fd4a7dfcf8 Merge pull request #102 from jon-tow/rvt-add-use-glu-flag
Add `use_glu` flag to `RvT`
2021-04-30 06:44:41 -07:00
Jonathan Tow
6f3a5fcf0b Add use_glu flag to RvT 2021-04-30 02:07:41 -04:00
Phil Wang
7807f24509 fix small bug 2021-04-29 15:39:41 -07:00
Phil Wang
a612327126 readme 2021-04-29 15:22:12 -07:00
Phil Wang
30a1335d31 release twins svt 2021-04-29 14:55:25 -07:00
Phil Wang
ab781f7ddb add Twins SVT (small) 2021-04-29 14:54:06 -07:00
Zack Ankner
a2df363224 adding un-normalizing targets and fix for mask token dimension 2021-04-29 15:43:22 -04:00
Phil Wang
4f3dbd003f for PiT, project to increased dimensions on first grouped conv for depthwise-conv 2021-04-29 12:41:00 -07:00
Zack Ankner
710b6d57d3 Merge pull request #1 from lucidrains/main
catch up
2021-04-29 19:33:25 +00:00
Phil Wang
60b5687a79 cleanup rvt 2021-04-27 11:45:46 -07:00
Phil Wang
0df1505662 add zeroing of weight parameters of batchnorm in levit just before residual connection, noticed by @EelcoHoogendoorn 2021-04-27 08:41:16 -07:00
Phil Wang
3df6c31c61 fix norm issues in cvt 2021-04-27 08:36:17 -07:00
Phil Wang
54af220930 fix cvt 2021-04-26 20:37:51 -07:00
Phil Wang
bad4b94e7b fix all issues with rotary vision transformer 2021-04-25 12:09:32 -07:00
Phil Wang
fbced01fe7 cite 2021-04-20 18:36:54 -07:00
Phil Wang
e42e9876bc offer a way to turn off ds conv in rotary vision transformer for ablation 2021-04-20 10:12:03 -07:00
Phil Wang
566365978d add ability to turn off rotary, for ablation 2021-04-20 09:00:27 -07:00
Phil Wang
34f78294d3 fix pooling bugs across a few new archs 2021-04-19 22:36:23 -07:00
62 changed files with 5734 additions and 157 deletions

3
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1,3 @@
# These are supported funding model platforms
github: [lucidrains]

33
.github/workflows/python-test.yml vendored Normal file
View File

@@ -0,0 +1,33 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Test
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
python setup.py test

1
MANIFEST.in Normal file
View File

@@ -0,0 +1 @@
recursive-include tests *

1118
README.md

File diff suppressed because it is too large Load Diff

View File

@@ -364,9 +364,8 @@
"\n",
"val_transforms = transforms.Compose(\n",
" [\n",
" transforms.Resize((224, 224)),\n",
" transforms.RandomResizedCrop(224),\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.Resize(256),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" ]\n",
")\n",
@@ -374,9 +373,8 @@
"\n",
"test_transforms = transforms.Compose(\n",
" [\n",
" transforms.Resize((224, 224)),\n",
" transforms.RandomResizedCrop(224),\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.Resize(256),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" ]\n",
")\n"
@@ -6250,4 +6248,4 @@
},
"nbformat": 4,
"nbformat_minor": 1
}
}

BIN
images/ats.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 198 KiB

BIN
images/crossformer.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 169 KiB

BIN
images/crossformer2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 237 KiB

BIN
images/dino.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

BIN
images/esvit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 190 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

BIN
images/mae.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 198 KiB

BIN
images/max-vit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 133 KiB

BIN
images/mbvit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 206 KiB

BIN
images/nest.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB

BIN
images/parallel-vit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

BIN
images/patch_merger.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

BIN
images/regionvit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 94 KiB

BIN
images/regionvit2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 55 KiB

BIN
images/scalable-vit-1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 79 KiB

BIN
images/scalable-vit-2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 62 KiB

BIN
images/sep-vit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 142 KiB

BIN
images/simmim.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 365 KiB

BIN
images/twins_svt.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 80 KiB

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.16.4',
version = '0.35.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
@@ -15,8 +15,15 @@ setup(
'image recognition'
],
install_requires=[
'torch>=1.6',
'einops>=0.3'
'einops>=0.4.1',
'torch>=1.10',
'torchvision'
],
setup_requires=[
'pytest-runner',
],
tests_require=[
'pytest'
],
classifiers=[
'Development Status :: 4 - Beta',

20
tests/test.py Normal file
View File

@@ -0,0 +1,20 @@
import torch
from vit_pytorch import ViT
def test():
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img)
assert preds.shape == (1, 1000), 'correct logits outputted'

View File

@@ -1 +1,5 @@
from vit_pytorch.vit import ViT
from vit_pytorch.simple_vit import SimpleViT
from vit_pytorch.mae import MAE
from vit_pytorch.dino import Dino

265
vit_pytorch/ats_vit.py Normal file
View File

@@ -0,0 +1,265 @@
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# adaptive token sampling functions and classes
def log(t, eps = 1e-6):
return torch.log(t + eps)
def sample_gumbel(shape, device, dtype, eps = 1e-6):
u = torch.empty(shape, device = device, dtype = dtype).uniform_(0, 1)
return -log(-log(u, eps), eps)
def batched_index_select(values, indices, dim = 1):
value_dims = values.shape[(dim + 1):]
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
indices = indices[(..., *((None,) * len(value_dims)))]
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
value_expand_len = len(indices_shape) - (dim + 1)
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
value_expand_shape = [-1] * len(values.shape)
expand_slice = slice(dim, (dim + value_expand_len))
value_expand_shape[expand_slice] = indices.shape[expand_slice]
values = values.expand(*value_expand_shape)
dim += value_expand_len
return values.gather(dim, indices)
class AdaptiveTokenSampling(nn.Module):
def __init__(self, output_num_tokens, eps = 1e-6):
super().__init__()
self.eps = eps
self.output_num_tokens = output_num_tokens
def forward(self, attn, value, mask):
heads, output_num_tokens, eps, device, dtype = attn.shape[1], self.output_num_tokens, self.eps, attn.device, attn.dtype
# first get the attention values for CLS token to all other tokens
cls_attn = attn[..., 0, 1:]
# calculate the norms of the values, for weighting the scores, as described in the paper
value_norms = value[..., 1:, :].norm(dim = -1)
# weigh the attention scores by the norm of the values, sum across all heads
cls_attn = einsum('b h n, b h n -> b n', cls_attn, value_norms)
# normalize to 1
normed_cls_attn = cls_attn / (cls_attn.sum(dim = -1, keepdim = True) + eps)
# instead of using inverse transform sampling, going to invert the softmax and use gumbel-max sampling instead
pseudo_logits = log(normed_cls_attn)
# mask out pseudo logits for gumbel-max sampling
mask_without_cls = mask[:, 1:]
mask_value = -torch.finfo(attn.dtype).max / 2
pseudo_logits = pseudo_logits.masked_fill(~mask_without_cls, mask_value)
# expand k times, k being the adaptive sampling number
pseudo_logits = repeat(pseudo_logits, 'b n -> b k n', k = output_num_tokens)
pseudo_logits = pseudo_logits + sample_gumbel(pseudo_logits.shape, device = device, dtype = dtype)
# gumble-max and add one to reserve 0 for padding / mask
sampled_token_ids = pseudo_logits.argmax(dim = -1) + 1
# calculate unique using torch.unique and then pad the sequence from the right
unique_sampled_token_ids_list = [torch.unique(t, sorted = True) for t in torch.unbind(sampled_token_ids)]
unique_sampled_token_ids = pad_sequence(unique_sampled_token_ids_list, batch_first = True)
# calculate the new mask, based on the padding
new_mask = unique_sampled_token_ids != 0
# CLS token never gets masked out (gets a value of True)
new_mask = F.pad(new_mask, (1, 0), value = True)
# prepend a 0 token id to keep the CLS attention scores
unique_sampled_token_ids = F.pad(unique_sampled_token_ids, (1, 0), value = 0)
expanded_unique_sampled_token_ids = repeat(unique_sampled_token_ids, 'b n -> b h n', h = heads)
# gather the new attention scores
new_attn = batched_index_select(attn, expanded_unique_sampled_token_ids, dim = 2)
# return the sampled attention scores, new mask (denoting padding), as well as the sampled token indices (for the residual)
return new_attn, new_mask, unique_sampled_token_ids
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_tokens = None):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.output_num_tokens = output_num_tokens
self.ats = AdaptiveTokenSampling(output_num_tokens) if exists(output_num_tokens) else None
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, *, mask):
num_tokens = x.shape[1]
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(mask):
dots_mask = rearrange(mask, 'b i -> b 1 i 1') * rearrange(mask, 'b j -> b 1 1 j')
mask_value = -torch.finfo(dots.dtype).max
dots = dots.masked_fill(~dots_mask, mask_value)
attn = self.attend(dots)
attn = self.dropout(attn)
sampled_token_ids = None
# if adaptive token sampling is enabled
# and number of tokens is greater than the number of output tokens
if exists(self.output_num_tokens) and (num_tokens - 1) > self.output_num_tokens:
attn, mask, sampled_token_ids = self.ats(attn, v, mask = mask)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out), mask, sampled_token_ids
class Transformer(nn.Module):
def __init__(self, dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
assert len(max_tokens_per_depth) == depth, 'max_tokens_per_depth must be a tuple of length that is equal to the depth of the transformer'
assert sorted(max_tokens_per_depth, reverse = True) == list(max_tokens_per_depth), 'max_tokens_per_depth must be in decreasing order'
assert min(max_tokens_per_depth) > 0, 'max_tokens_per_depth must have at least 1 token at any layer'
self.layers = nn.ModuleList([])
for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
b, n, device = *x.shape[:2], x.device
# use mask to keep track of the paddings when sampling tokens
# as the duplicates (when sampling) are just removed, as mentioned in the paper
mask = torch.ones((b, n), device = device, dtype = torch.bool)
token_ids = torch.arange(n, device = device)
token_ids = repeat(token_ids, 'n -> b n', b = b)
for attn, ff in self.layers:
attn_out, mask, sampled_token_ids = attn(x, mask = mask)
# when token sampling, one needs to then gather the residual tokens with the sampled token ids
if exists(sampled_token_ids):
x = batched_index_select(x, sampled_token_ids, dim = 1)
token_ids = batched_index_select(token_ids, sampled_token_ids, dim = 1)
x = x + attn_out
x = ff(x) + x
return x, token_ids
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, max_tokens_per_depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img, return_sampled_token_ids = False):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x, token_ids = self.transformer(x)
logits = self.mlp_head(x[:, 0])
if return_sampled_token_ids:
# remove CLS token and decrement by 1 to make -1 the padding
token_ids = token_ids[:, 1:] - 1
return logits, token_ids
return logits

View File

@@ -76,6 +76,7 @@ class Attention(nn.Module):
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))
@@ -96,7 +97,10 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax
attn = self.attend(dots)
attn = self.dropout(attn)
attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax
out = einsum('b h i j, b h j d -> b h i d', attn, v)

345
vit_pytorch/cct.py Normal file
View File

@@ -0,0 +1,345 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# CCT Models
__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']
def cct_2(*args, **kwargs):
return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
*args, **kwargs)
def cct_4(*args, **kwargs):
return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
*args, **kwargs)
def cct_6(*args, **kwargs):
return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs)
def cct_7(*args, **kwargs):
return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs)
def cct_8(*args, **kwargs):
return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs)
def cct_14(*args, **kwargs):
return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
*args, **kwargs)
def cct_16(*args, **kwargs):
return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384,
*args, **kwargs)
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
kernel_size=3, stride=None, padding=None,
*args, **kwargs):
stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
padding = padding if padding is not None else max(1, (kernel_size // 2))
return CCT(num_layers=num_layers,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
embedding_dim=embedding_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
*args, **kwargs)
# modules
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
super().__init__()
self.num_heads = num_heads
head_dim = dim // self.num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.attn_drop = nn.Dropout(attention_dropout)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(projection_dropout)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class TransformerEncoderLayer(nn.Module):
"""
Inspired by torch.nn.TransformerEncoderLayer and
rwightman's timm package.
"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
attention_dropout=0.1, drop_path_rate=0.1):
super(TransformerEncoderLayer, self).__init__()
self.pre_norm = nn.LayerNorm(d_model)
self.self_attn = Attention(dim=d_model, num_heads=nhead,
attention_dropout=attention_dropout, projection_dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout2 = nn.Dropout(dropout)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.activation = F.gelu
def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
src = self.norm1(src)
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
src = src + self.drop_path(self.dropout2(src2))
return src
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""
Obtained from: github.com:rwightman/pytorch-image-models
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""
Obtained from: github.com:rwightman/pytorch-image-models
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class Tokenizer(nn.Module):
def __init__(self,
kernel_size, stride, padding,
pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
n_conv_layers=1,
n_input_channels=3,
n_output_channels=64,
in_planes=64,
activation=None,
max_pool=True,
conv_bias=False):
super(Tokenizer, self).__init__()
n_filter_list = [n_input_channels] + \
[in_planes for _ in range(n_conv_layers - 1)] + \
[n_output_channels]
self.conv_layers = nn.Sequential(
*[nn.Sequential(
nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
kernel_size=(kernel_size, kernel_size),
stride=(stride, stride),
padding=(padding, padding), bias=conv_bias),
nn.Identity() if activation is None else activation(),
nn.MaxPool2d(kernel_size=pooling_kernel_size,
stride=pooling_stride,
padding=pooling_padding) if max_pool else nn.Identity()
)
for i in range(n_conv_layers)
])
self.flattener = nn.Flatten(2, 3)
self.apply(self.init_weight)
def sequence_length(self, n_channels=3, height=224, width=224):
return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
def forward(self, x):
return self.flattener(self.conv_layers(x)).transpose(-2, -1)
@staticmethod
def init_weight(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
class TransformerClassifier(nn.Module):
def __init__(self,
seq_pool=True,
embedding_dim=768,
num_layers=12,
num_heads=12,
mlp_ratio=4.0,
num_classes=1000,
dropout_rate=0.1,
attention_dropout=0.1,
stochastic_depth_rate=0.1,
positional_embedding='sine',
sequence_length=None,
*args, **kwargs):
super().__init__()
positional_embedding = positional_embedding if \
positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
dim_feedforward = int(embedding_dim * mlp_ratio)
self.embedding_dim = embedding_dim
self.sequence_length = sequence_length
self.seq_pool = seq_pool
assert sequence_length is not None or positional_embedding == 'none', \
f"Positional embedding is set to {positional_embedding} and" \
f" the sequence length was not specified."
if not seq_pool:
sequence_length += 1
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim),
requires_grad=True)
else:
self.attention_pool = nn.Linear(self.embedding_dim, 1)
if positional_embedding != 'none':
if positional_embedding == 'learnable':
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
requires_grad=True)
nn.init.trunc_normal_(self.positional_emb, std=0.2)
else:
self.positional_emb = nn.Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
requires_grad=False)
else:
self.positional_emb = None
self.dropout = nn.Dropout(p=dropout_rate)
dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
self.blocks = nn.ModuleList([
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
dim_feedforward=dim_feedforward, dropout=dropout_rate,
attention_dropout=attention_dropout, drop_path_rate=dpr[i])
for i in range(num_layers)])
self.norm = nn.LayerNorm(embedding_dim)
self.fc = nn.Linear(embedding_dim, num_classes)
self.apply(self.init_weight)
def forward(self, x):
if self.positional_emb is None and x.size(1) < self.sequence_length:
x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)
if not self.seq_pool:
cls_token = self.class_emb.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
if self.positional_emb is not None:
x += self.positional_emb
x = self.dropout(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
if self.seq_pool:
x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
else:
x = x[:, 0]
x = self.fc(x)
return x
@staticmethod
def init_weight(m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@staticmethod
def sinusoidal_embedding(n_channels, dim):
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
for p in range(n_channels)])
pe[:, 0::2] = torch.sin(pe[:, 0::2])
pe[:, 1::2] = torch.cos(pe[:, 1::2])
return pe.unsqueeze(0)
# CCT Main model
class CCT(nn.Module):
def __init__(self,
img_size=224,
embedding_dim=768,
n_input_channels=3,
n_conv_layers=1,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
*args, **kwargs):
super(CCT, self).__init__()
img_height, img_width = pair(img_size)
self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
n_output_channels=embedding_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
pooling_kernel_size=pooling_kernel_size,
pooling_stride=pooling_stride,
pooling_padding=pooling_padding,
max_pool=True,
activation=nn.ReLU,
n_conv_layers=n_conv_layers,
conv_bias=False)
self.classifier = TransformerClassifier(
sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
height=img_height,
width=img_width),
embedding_dim=embedding_dim,
seq_pool=True,
dropout_rate=0.,
attention_dropout=0.1,
stochastic_depth=0.1,
*args, **kwargs)
def forward(self, x):
x = self.tokenizer(x)
return self.classifier(x)

View File

@@ -48,6 +48,8 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -69,6 +71,7 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')

267
vit_pytorch/crossformer.py Normal file
View File

@@ -0,0 +1,267 @@
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
import torch.nn.functional as F
# helpers
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
# cross embed layer
class CrossEmbedLayer(nn.Module):
def __init__(
self,
dim_in,
dim_out,
kernel_sizes,
stride = 2
):
super().__init__()
kernel_sizes = sorted(kernel_sizes)
num_scales = len(kernel_sizes)
# calculate the dimension at each scale
dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
self.convs = nn.ModuleList([])
for kernel, dim_scale in zip(kernel_sizes, dim_scales):
self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
def forward(self, x):
fmaps = tuple(map(lambda conv: conv(x), self.convs))
return torch.cat(fmaps, dim = 1)
# dynamic positional bias
def DynamicPositionBias(dim):
return nn.Sequential(
nn.Linear(2, dim),
nn.LayerNorm(dim),
nn.ReLU(),
nn.Linear(dim, dim),
nn.LayerNorm(dim),
nn.ReLU(),
nn.Linear(dim, dim),
nn.LayerNorm(dim),
nn.ReLU(),
nn.Linear(dim, 1),
Rearrange('... () -> ...')
)
# transformer classes
class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
def FeedForward(dim, mult = 4, dropout = 0.):
return nn.Sequential(
LayerNorm(dim),
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(dim * mult, dim, 1)
)
class Attention(nn.Module):
def __init__(
self,
dim,
attn_type,
window_size,
dim_head = 32,
dropout = 0.
):
super().__init__()
assert attn_type in {'short', 'long'}, 'attention type must be one of local or distant'
heads = dim // dim_head
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.attn_type = attn_type
self.window_size = window_size
self.norm = LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1)
# positions
self.dpb = DynamicPositionBias(dim // 4)
# calculate and store indices for retrieving bias
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = grid[:, None] - grid[None, :]
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
def forward(self, x):
*_, height, width, heads, wsz, device = *x.shape, self.heads, self.window_size, x.device
# prenorm
x = self.norm(x)
# rearrange for short or long distance attention
if self.attn_type == 'short':
x = rearrange(x, 'b d (h s1) (w s2) -> (b h w) d s1 s2', s1 = wsz, s2 = wsz)
elif self.attn_type == 'long':
x = rearrange(x, 'b d (l1 h) (l2 w) -> (b h w) d l1 l2', l1 = wsz, l2 = wsz)
# queries / keys / values
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
# split heads
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), (q, k, v))
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
# add dynamic positional bias
pos = torch.arange(-wsz, wsz + 1, device = device)
rel_pos = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
biases = self.dpb(rel_pos.float())
rel_pos_bias = biases[self.rel_pos_indices]
sim = sim + rel_pos_bias
# attend
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# merge heads
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = wsz, y = wsz)
out = self.to_out(out)
# rearrange back for long or short distance attention
if self.attn_type == 'short':
out = rearrange(out, '(b h w) d s1 s2 -> b d (h s1) (w s2)', h = height // wsz, w = width // wsz)
elif self.attn_type == 'long':
out = rearrange(out, '(b h w) d l1 l2 -> b d (l1 h) (l2 w)', h = height // wsz, w = width // wsz)
return out
class Transformer(nn.Module):
def __init__(
self,
dim,
*,
local_window_size,
global_window_size,
depth = 4,
dim_head = 32,
attn_dropout = 0.,
ff_dropout = 0.,
):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, attn_type = 'short', window_size = local_window_size, dim_head = dim_head, dropout = attn_dropout),
FeedForward(dim, dropout = ff_dropout),
Attention(dim, attn_type = 'long', window_size = global_window_size, dim_head = dim_head, dropout = attn_dropout),
FeedForward(dim, dropout = ff_dropout)
]))
def forward(self, x):
for short_attn, short_ff, long_attn, long_ff in self.layers:
x = short_attn(x) + x
x = short_ff(x) + x
x = long_attn(x) + x
x = long_ff(x) + x
return x
# classes
class CrossFormer(nn.Module):
def __init__(
self,
*,
dim = (64, 128, 256, 512),
depth = (2, 2, 8, 2),
global_window_size = (8, 4, 2, 1),
local_window_size = 7,
cross_embed_kernel_sizes = ((4, 8, 16, 32), (2, 4), (2, 4), (2, 4)),
cross_embed_strides = (4, 2, 2, 2),
num_classes = 1000,
attn_dropout = 0.,
ff_dropout = 0.,
channels = 3
):
super().__init__()
dim = cast_tuple(dim, 4)
depth = cast_tuple(depth, 4)
global_window_size = cast_tuple(global_window_size, 4)
local_window_size = cast_tuple(local_window_size, 4)
cross_embed_kernel_sizes = cast_tuple(cross_embed_kernel_sizes, 4)
cross_embed_strides = cast_tuple(cross_embed_strides, 4)
assert len(dim) == 4
assert len(depth) == 4
assert len(global_window_size) == 4
assert len(local_window_size) == 4
assert len(cross_embed_kernel_sizes) == 4
assert len(cross_embed_strides) == 4
# dimensions
last_dim = dim[-1]
dims = [channels, *dim]
dim_in_and_out = tuple(zip(dims[:-1], dims[1:]))
# layers
self.layers = nn.ModuleList([])
for (dim_in, dim_out), layers, global_wsz, local_wsz, cel_kernel_sizes, cel_stride in zip(dim_in_and_out, depth, global_window_size, local_window_size, cross_embed_kernel_sizes, cross_embed_strides):
self.layers.append(nn.ModuleList([
CrossEmbedLayer(dim_in, dim_out, cel_kernel_sizes, stride = cel_stride),
Transformer(dim_out, local_window_size = local_wsz, global_window_size = global_wsz, depth = layers, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
]))
# final logits
self.to_logits = nn.Sequential(
Reduce('b c h w -> b c', 'mean'),
nn.Linear(last_dim, num_classes)
)
def forward(self, x):
for cel, transformer in self.layers:
x = cel(x)
x = transformer(x)
return self.to_logits(x)

View File

@@ -22,15 +22,25 @@ def group_by_key_prefix_and_remove_prefix(prefix, d):
# classes
class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
x = rearrange(x, 'b c h w -> b h w c')
x = self.norm(x)
x = rearrange(x, 'b h w c -> b c h w')
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
@@ -66,9 +76,10 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = DepthWiseConv2d(dim, inner_dim, 3, padding = padding, stride = 1, bias = False)
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = padding, stride = kv_proj_stride, bias = False)
self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
@@ -84,6 +95,7 @@ class Attention(nn.Module):
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
@@ -130,7 +142,7 @@ class CvT(nn.Module):
s3_emb_stride = 2,
s3_proj_kernel = 3,
s3_kv_proj_stride = 2,
s3_heads = 4,
s3_heads = 6,
s3_depth = 10,
s3_mlp_mult = 4,
dropout = 0.
@@ -146,17 +158,20 @@ class CvT(nn.Module):
layers.append(nn.Sequential(
nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
LayerNorm(config['emb_dim']),
Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)
))
dim = config['emb_dim']
self.layers = nn.Sequential(
*layers,
self.layers = nn.Sequential(*layers)
self.to_logits = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
Rearrange('... () () -> ...'),
nn.Linear(dim, num_classes)
)
def forward(self, x):
return self.layers(x)
latents = self.layers(x)
return self.to_logits(latents)

View File

@@ -42,6 +42,8 @@ class Attention(nn.Module):
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.dropout = nn.Dropout(dropout)
self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
self.reattn_norm = nn.Sequential(
@@ -64,6 +66,7 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = dots.softmax(dim=-1)
attn = self.dropout(attn)
# re-attention

303
vit_pytorch/dino.py Normal file
View File

@@ -0,0 +1,303 @@
import copy
import random
from functools import wraps, partial
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms as T
# helper functions
def exists(val):
return val is not None
def default(val, default):
return val if exists(val) else default
def singleton(cache_key):
def inner_fn(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
instance = getattr(self, cache_key)
if instance is not None:
return instance
instance = fn(self, *args, **kwargs)
setattr(self, cache_key, instance)
return instance
return wrapper
return inner_fn
def get_module_device(module):
return next(module.parameters()).device
def set_requires_grad(model, val):
for p in model.parameters():
p.requires_grad = val
# loss function # (algorithm 1 in the paper)
def loss_fn(
teacher_logits,
student_logits,
teacher_temp,
student_temp,
centers,
eps = 1e-20
):
teacher_logits = teacher_logits.detach()
student_probs = (student_logits / student_temp).softmax(dim = -1)
teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
return - (teacher_probs * torch.log(student_probs + eps)).sum(dim = -1).mean()
# augmentation utils
class RandomApply(nn.Module):
def __init__(self, fn, p):
super().__init__()
self.fn = fn
self.p = p
def forward(self, x):
if random.random() > self.p:
return x
return self.fn(x)
# exponential moving average
class EMA():
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new
def update_moving_average(ema_updater, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = ema_updater.update_average(old_weight, up_weight)
# MLP class for projector and predictor
class L2Norm(nn.Module):
def forward(self, x, eps = 1e-6):
norm = x.norm(dim = 1, keepdim = True).clamp(min = eps)
return x / norm
class MLP(nn.Module):
def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
super().__init__()
layers = []
dims = (dim, *((hidden_size,) * (num_layers - 1)))
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
is_last = ind == (len(dims) - 1)
layers.extend([
nn.Linear(layer_dim_in, layer_dim_out),
nn.GELU() if not is_last else nn.Identity()
])
self.net = nn.Sequential(
*layers,
L2Norm(),
nn.Linear(hidden_size, dim_out)
)
def forward(self, x):
return self.net(x)
# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projecter and predictor nets
class NetWrapper(nn.Module):
def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
super().__init__()
self.net = net
self.layer = layer
self.projector = None
self.projection_hidden_size = projection_hidden_size
self.projection_num_layers = projection_num_layers
self.output_dim = output_dim
self.hidden = {}
self.hook_registered = False
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
def _hook(self, _, input, output):
device = input[0].device
self.hidden[device] = output.flatten(1)
def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
@singleton('projector')
def _get_projector(self, hidden):
_, dim = hidden.shape
projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
return projector.to(hidden)
def get_embedding(self, x):
if self.layer == -1:
return self.net(x)
if not self.hook_registered:
self._register_hook()
self.hidden.clear()
_ = self.net(x)
hidden = self.hidden[x.device]
self.hidden.clear()
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
def forward(self, x, return_projection = True):
embed = self.get_embedding(x)
if not return_projection:
return embed
projector = self._get_projector(embed)
return projector(embed), embed
# main class
class Dino(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer = -2,
projection_hidden_size = 256,
num_classes_K = 65336,
projection_layers = 4,
student_temp = 0.9,
teacher_temp = 0.04,
local_upper_crop_scale = 0.4,
global_lower_crop_scale = 0.5,
moving_average_decay = 0.9,
center_moving_average_decay = 0.9,
augment_fn = None,
augment_fn2 = None
):
super().__init__()
self.net = net
# default BYOL augmentation
DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, DEFAULT_AUG)
# local and global crops
self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))
self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)
self.teacher_encoder = None
self.teacher_ema_updater = EMA(moving_average_decay)
self.register_buffer('teacher_centers', torch.zeros(1, num_classes_K))
self.register_buffer('last_teacher_centers', torch.zeros(1, num_classes_K))
self.teacher_centering_ema_updater = EMA(center_moving_average_decay)
self.student_temp = student_temp
self.teacher_temp = teacher_temp
# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)
# send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
@singleton('teacher_encoder')
def _get_teacher_encoder(self):
teacher_encoder = copy.deepcopy(self.student_encoder)
set_requires_grad(teacher_encoder, False)
return teacher_encoder
def reset_moving_average(self):
del self.teacher_encoder
self.teacher_encoder = None
def update_moving_average(self):
assert self.teacher_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)
new_teacher_centers = self.teacher_centering_ema_updater.update_average(self.teacher_centers, self.last_teacher_centers)
self.teacher_centers.copy_(new_teacher_centers)
def forward(
self,
x,
return_embedding = False,
return_projection = True,
student_temp = None,
teacher_temp = None
):
if return_embedding:
return self.student_encoder(x, return_projection = return_projection)
image_one, image_two = self.augment1(x), self.augment2(x)
local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
student_proj_one, _ = self.student_encoder(local_image_one)
student_proj_two, _ = self.student_encoder(local_image_two)
with torch.no_grad():
teacher_encoder = self._get_teacher_encoder()
teacher_proj_one, _ = teacher_encoder(global_image_one)
teacher_proj_two, _ = teacher_encoder(global_image_two)
loss_fn_ = partial(
loss_fn,
student_temp = default(student_temp, self.student_temp),
teacher_temp = default(teacher_temp, self.teacher_temp),
centers = self.teacher_centers
)
teacher_logits_avg = torch.cat((teacher_proj_one, teacher_proj_two)).mean(dim = 0)
self.last_teacher_centers.copy_(teacher_logits_avg)
loss = (loss_fn_(teacher_proj_one, student_proj_two) + loss_fn_(teacher_proj_two, student_proj_one)) / 2
return loss

View File

@@ -148,6 +148,6 @@ class DistillWrapper(nn.Module):
else:
teacher_labels = teacher_logits.argmax(dim = -1)
distill_loss = F.cross_entropy(student_logits, teacher_labels)
distill_loss = F.cross_entropy(distill_logits, teacher_labels)
return loss * (1 - alpha) + distill_loss * alpha

View File

@@ -3,12 +3,16 @@ from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
def pair(t):
return t if isinstance(t, tuple) else (t, t)
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
image_size_h, image_size_w = pair(image_size)
assert image_size_h % patch_size == 0 and image_size_w % patch_size == 0, 'image dimensions must be divisible by the patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_patches = (image_size // patch_size) ** 2
num_patches = (image_size_h // patch_size) * (image_size_w // patch_size)
patch_dim = channels * patch_size ** 2
self.to_patch_embedding = nn.Sequential(

367
vit_pytorch/es_vit.py Normal file
View File

@@ -0,0 +1,367 @@
import copy
import random
from functools import wraps, partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torchvision import transforms as T
from einops import rearrange, reduce, repeat
# helper functions
def exists(val):
return val is not None
def default(val, default):
return val if exists(val) else default
def singleton(cache_key):
def inner_fn(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
instance = getattr(self, cache_key)
if instance is not None:
return instance
instance = fn(self, *args, **kwargs)
setattr(self, cache_key, instance)
return instance
return wrapper
return inner_fn
def get_module_device(module):
return next(module.parameters()).device
def set_requires_grad(model, val):
for p in model.parameters():
p.requires_grad = val
# tensor related helpers
def log(t, eps = 1e-20):
return torch.log(t + eps)
# loss function # (algorithm 1 in the paper)
def view_loss_fn(
teacher_logits,
student_logits,
teacher_temp,
student_temp,
centers,
eps = 1e-20
):
teacher_logits = teacher_logits.detach()
student_probs = (student_logits / student_temp).softmax(dim = -1)
teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
return - (teacher_probs * log(student_probs, eps)).sum(dim = -1).mean()
def region_loss_fn(
teacher_logits,
student_logits,
teacher_latent,
student_latent,
teacher_temp,
student_temp,
centers,
eps = 1e-20
):
teacher_logits = teacher_logits.detach()
student_probs = (student_logits / student_temp).softmax(dim = -1)
teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
sim_matrix = einsum('b i d, b j d -> b i j', student_latent, teacher_latent)
sim_indices = sim_matrix.max(dim = -1).indices
sim_indices = repeat(sim_indices, 'b n -> b n k', k = teacher_probs.shape[-1])
max_sim_teacher_probs = teacher_probs.gather(1, sim_indices)
return - (max_sim_teacher_probs * log(student_probs, eps)).sum(dim = -1).mean()
# augmentation utils
class RandomApply(nn.Module):
def __init__(self, fn, p):
super().__init__()
self.fn = fn
self.p = p
def forward(self, x):
if random.random() > self.p:
return x
return self.fn(x)
# exponential moving average
class EMA():
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new
def update_moving_average(ema_updater, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = ema_updater.update_average(old_weight, up_weight)
# MLP class for projector and predictor
class L2Norm(nn.Module):
def forward(self, x, eps = 1e-6):
return F.normalize(x, dim = 1, eps = eps)
class MLP(nn.Module):
def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
super().__init__()
layers = []
dims = (dim, *((hidden_size,) * (num_layers - 1)))
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
is_last = ind == (len(dims) - 1)
layers.extend([
nn.Linear(layer_dim_in, layer_dim_out),
nn.GELU() if not is_last else nn.Identity()
])
self.net = nn.Sequential(
*layers,
L2Norm(),
nn.Linear(hidden_size, dim_out)
)
def forward(self, x):
return self.net(x)
# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projecter and predictor nets
class NetWrapper(nn.Module):
def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
super().__init__()
self.net = net
self.layer = layer
self.view_projector = None
self.region_projector = None
self.projection_hidden_size = projection_hidden_size
self.projection_num_layers = projection_num_layers
self.output_dim = output_dim
self.hidden = {}
self.hook_registered = False
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
def _hook(self, _, input, output):
device = input[0].device
self.hidden[device] = output
def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
@singleton('view_projector')
def _get_view_projector(self, hidden):
dim = hidden.shape[1]
projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
return projector.to(hidden)
@singleton('region_projector')
def _get_region_projector(self, hidden):
dim = hidden.shape[1]
projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
return projector.to(hidden)
def get_embedding(self, x):
if self.layer == -1:
return self.net(x)
if not self.hook_registered:
self._register_hook()
self.hidden.clear()
_ = self.net(x)
hidden = self.hidden[x.device]
self.hidden.clear()
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
def forward(self, x, return_projection = True):
region_latents = self.get_embedding(x)
global_latent = reduce(region_latents, 'b c h w -> b c', 'mean')
if not return_projection:
return global_latent, region_latents
view_projector = self._get_view_projector(global_latent)
region_projector = self._get_region_projector(region_latents)
region_latents = rearrange(region_latents, 'b c h w -> b (h w) c')
return view_projector(global_latent), region_projector(region_latents), region_latents
# main class
class EsViTTrainer(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer = -2,
projection_hidden_size = 256,
num_classes_K = 65336,
projection_layers = 4,
student_temp = 0.9,
teacher_temp = 0.04,
local_upper_crop_scale = 0.4,
global_lower_crop_scale = 0.5,
moving_average_decay = 0.9,
center_moving_average_decay = 0.9,
augment_fn = None,
augment_fn2 = None
):
super().__init__()
self.net = net
# default BYOL augmentation
DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, DEFAULT_AUG)
# local and global crops
self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))
self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)
self.teacher_encoder = None
self.teacher_ema_updater = EMA(moving_average_decay)
self.register_buffer('teacher_view_centers', torch.zeros(1, num_classes_K))
self.register_buffer('last_teacher_view_centers', torch.zeros(1, num_classes_K))
self.register_buffer('teacher_region_centers', torch.zeros(1, num_classes_K))
self.register_buffer('last_teacher_region_centers', torch.zeros(1, num_classes_K))
self.teacher_centering_ema_updater = EMA(center_moving_average_decay)
self.student_temp = student_temp
self.teacher_temp = teacher_temp
# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)
# send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
@singleton('teacher_encoder')
def _get_teacher_encoder(self):
teacher_encoder = copy.deepcopy(self.student_encoder)
set_requires_grad(teacher_encoder, False)
return teacher_encoder
def reset_moving_average(self):
del self.teacher_encoder
self.teacher_encoder = None
def update_moving_average(self):
assert self.teacher_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)
new_teacher_view_centers = self.teacher_centering_ema_updater.update_average(self.teacher_view_centers, self.last_teacher_view_centers)
self.teacher_view_centers.copy_(new_teacher_view_centers)
new_teacher_region_centers = self.teacher_centering_ema_updater.update_average(self.teacher_region_centers, self.last_teacher_region_centers)
self.teacher_region_centers.copy_(new_teacher_region_centers)
def forward(
self,
x,
return_embedding = False,
return_projection = True,
student_temp = None,
teacher_temp = None
):
if return_embedding:
return self.student_encoder(x, return_projection = return_projection)
image_one, image_two = self.augment1(x), self.augment2(x)
local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
student_view_proj_one, student_region_proj_one, student_latent_one = self.student_encoder(local_image_one)
student_view_proj_two, student_region_proj_two, student_latent_two = self.student_encoder(local_image_two)
with torch.no_grad():
teacher_encoder = self._get_teacher_encoder()
teacher_view_proj_one, teacher_region_proj_one, teacher_latent_one = teacher_encoder(global_image_one)
teacher_view_proj_two, teacher_region_proj_two, teacher_latent_two = teacher_encoder(global_image_two)
view_loss_fn_ = partial(
view_loss_fn,
student_temp = default(student_temp, self.student_temp),
teacher_temp = default(teacher_temp, self.teacher_temp),
centers = self.teacher_view_centers
)
region_loss_fn_ = partial(
region_loss_fn,
student_temp = default(student_temp, self.student_temp),
teacher_temp = default(teacher_temp, self.teacher_temp),
centers = self.teacher_region_centers
)
# calculate view-level loss
teacher_view_logits_avg = torch.cat((teacher_view_proj_one, teacher_view_proj_two)).mean(dim = 0)
self.last_teacher_view_centers.copy_(teacher_view_logits_avg)
teacher_region_logits_avg = torch.cat((teacher_region_proj_one, teacher_region_proj_two)).mean(dim = (0, 1))
self.last_teacher_region_centers.copy_(teacher_region_logits_avg)
view_loss = (view_loss_fn_(teacher_view_proj_one, student_view_proj_two) \
+ view_loss_fn_(teacher_view_proj_two, student_view_proj_one)) / 2
# calculate region-level loss
region_loss = (region_loss_fn_(teacher_region_proj_one, student_region_proj_two, teacher_latent_one, student_latent_two) \
+ region_loss_fn_(teacher_region_proj_two, student_region_proj_one, teacher_latent_two, student_latent_one)) / 2
return (view_loss + region_loss) / 2

70
vit_pytorch/extractor.py Normal file
View File

@@ -0,0 +1,70 @@
import torch
from torch import nn
def exists(val):
return val is not None
class Extractor(nn.Module):
def __init__(
self,
vit,
device = None,
layer_name = 'transformer',
layer_save_input = False,
return_embeddings_only = False
):
super().__init__()
self.vit = vit
self.data = None
self.latents = None
self.hooks = []
self.hook_registered = False
self.ejected = False
self.device = device
self.layer_name = layer_name
self.layer_save_input = layer_save_input # whether to save input or output of layer
self.return_embeddings_only = return_embeddings_only
def _hook(self, _, inputs, output):
tensor_to_save = inputs if self.layer_save_input else output
self.latents = tensor_to_save.clone().detach()
def _register_hook(self):
assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
layer = getattr(self.vit, self.layer_name)
handle = layer.register_forward_hook(self._hook)
self.hooks.append(handle)
self.hook_registered = True
def eject(self):
self.ejected = True
for hook in self.hooks:
hook.remove()
self.hooks.clear()
return self.vit
def clear(self):
del self.latents
self.latents = None
def forward(
self,
img,
return_embeddings_only = False
):
assert not self.ejected, 'extractor has been ejected, cannot be used anymore'
self.clear()
if not self.hook_registered:
self._register_hook()
pred = self.vit(img)
target_device = self.device if exists(self.device) else img.device
latents = self.latents.to(target_device)
if return_embeddings_only or self.return_embeddings_only:
return latents
return pred, latents

View File

@@ -0,0 +1,216 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# controlling freezing of layers
def set_module_requires_grad_(module, requires_grad):
for param in module.parameters():
param.requires_grad = requires_grad
def freeze_all_layers_(module):
set_module_requires_grad_(module, False)
def unfreeze_all_layers_(module):
set_module_requires_grad_(module, True)
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, attn_mask = None, memories = None):
x = self.norm(x)
x_kv = x # input for key / values projection
if exists(memories):
# add memories to key / values if it is passed in
memories = repeat(memories, 'n d -> b n d', b = x.shape[0]) if memories.ndim == 2 else memories
x_kv = torch.cat((x_kv, memories), dim = 1)
qkv = (self.to_q(x), *self.to_kv(x_kv).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(attn_mask):
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, attn_mask = None, memories = None):
for ind, (attn, ff) in enumerate(self.layers):
layer_memories = memories[ind] if exists(memories) else None
x = attn(x, attn_mask = attn_mask, memories = layer_memories) + x
x = ff(x) + x
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def img_to_tokens(self, img):
x = self.to_patch_embedding(img)
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = x.shape[0])
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding
x = self.dropout(x)
return x
def forward(self, img):
x = self.img_to_tokens(img)
x = self.transformer(x)
cls_tokens = x[:, 0]
return self.mlp_head(cls_tokens)
# adapter with learnable memories per layer, memory CLS token, and learnable adapter head
class Adapter(nn.Module):
def __init__(
self,
*,
vit,
num_memories_per_layer = 10,
num_classes = 2,
):
super().__init__()
assert isinstance(vit, ViT)
# extract some model variables needed
dim = vit.cls_token.shape[-1]
layers = len(vit.transformer.layers)
num_patches = vit.pos_embedding.shape[-2]
self.vit = vit
# freeze ViT backbone - only memories will be finetuned
freeze_all_layers_(vit)
# learnable parameters
self.memory_cls_token = nn.Parameter(torch.randn(dim))
self.memories_per_layer = nn.Parameter(torch.randn(layers, num_memories_per_layer, dim))
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
# specialized attention mask to preserve the output of the original ViT
# it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa
attn_mask = torch.ones((num_patches, num_patches), dtype = torch.bool)
attn_mask = F.pad(attn_mask, (1, num_memories_per_layer), value = False) # main tokens cannot attend to learnable memories per layer
attn_mask = F.pad(attn_mask, (0, 0, 1, 0), value = True) # memory CLS token can attend to everything
self.register_buffer('attn_mask', attn_mask)
def forward(self, img):
b = img.shape[0]
tokens = self.vit.img_to_tokens(img)
# add task specific memory tokens
memory_cls_tokens = repeat(self.memory_cls_token, 'd -> b 1 d', b = b)
tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)
# pass memories along with image tokens through transformer for attending
out = self.vit.transformer(tokens, memories = self.memories_per_layer, attn_mask = self.attn_mask)
# extract memory CLS tokens
memory_cls_tokens = out[:, 0]
# pass through task specific adapter head
return self.mlp_head(memory_cls_tokens)

View File

@@ -29,7 +29,7 @@ class FeedForward(nn.Module):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Hardswish(),
nn.Dropout(dropout),
nn.Conv2d(dim * mult, dim, 1),
nn.Dropout(dropout)
@@ -52,11 +52,15 @@ class Attention(nn.Module):
self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
out_batch_norm = nn.BatchNorm2d(dim_out)
nn.init.zeros_(out_batch_norm.weight)
self.to_out = nn.Sequential(
nn.GELU(),
nn.Conv2d(inner_dim_value, dim_out, 1),
nn.BatchNorm2d(dim_out),
out_batch_norm,
nn.Dropout(dropout)
)
@@ -67,8 +71,8 @@ class Attention(nn.Module):
q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
k_range = torch.arange(fmap_size)
q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)
q_pos = torch.stack(torch.meshgrid(q_range, q_range, indexing = 'ij'), dim = -1)
k_pos = torch.stack(torch.meshgrid(k_range, k_range, indexing = 'ij'), dim = -1)
q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
@@ -81,7 +85,7 @@ class Attention(nn.Module):
def apply_pos_bias(self, fmap):
bias = self.pos_bias(self.pos_indices)
bias = rearrange(bias, 'i j h -> () h i j')
return fmap + bias
return fmap + (bias / self.scale)
def forward(self, x):
b, n, *_, h = *x.shape, self.heads
@@ -97,6 +101,7 @@ class Attention(nn.Module):
dots = self.apply_pos_bias(dots)
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)

View File

@@ -78,6 +78,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
@@ -93,6 +94,7 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
@@ -149,4 +151,4 @@ class LocalViT(nn.Module):
x = self.transformer(x)
return self.mlp_head(x)
return self.mlp_head(x[:, 0])

96
vit_pytorch/mae.py Normal file
View File

@@ -0,0 +1,96 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from vit_pytorch.vit import Transformer
class MAE(nn.Module):
def __init__(
self,
*,
encoder,
decoder_dim,
masking_ratio = 0.75,
decoder_depth = 1,
decoder_heads = 8,
decoder_dim_head = 64
):
super().__init__()
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
self.masking_ratio = masking_ratio
# extract some hyperparameters and functions from encoder (vision transformer to be trained)
self.encoder = encoder
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
# decoder parameters
self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
self.mask_token = nn.Parameter(torch.randn(decoder_dim))
self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)
def forward(self, img):
device = img.device
# get patches
patches = self.to_patch(img)
batch, num_patches, *_ = patches.shape
# patch to encoder tokens and add positions
tokens = self.patch_to_emb(patches)
tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]
# calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked
num_masked = int(self.masking_ratio * num_patches)
rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]
# get the unmasked tokens to be encoded
batch_range = torch.arange(batch, device = device)[:, None]
tokens = tokens[batch_range, unmasked_indices]
# get the patches to be masked for the final reconstruction loss
masked_patches = patches[batch_range, masked_indices]
# attend with vision transformer
encoded_tokens = self.encoder.transformer(tokens)
# project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder
decoder_tokens = self.enc_to_dec(encoded_tokens)
# reapply decoder position embedding to unmasked tokens
decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
# concat the masked tokens to the decoder tokens and attend with decoder
decoder_tokens = torch.cat((mask_tokens, decoder_tokens), dim = 1)
decoded_tokens = self.decoder(decoder_tokens)
# splice out the mask tokens and project to pixel values
mask_tokens = decoded_tokens[:, :num_masked]
pred_pixel_values = self.to_pixels(mask_tokens)
# calculate reconstruction loss
recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
return recon_loss

286
vit_pytorch/max_vit.py Normal file
View File

@@ -0,0 +1,286 @@
from functools import partial
import torch
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
# helper classes
class PreNormResidual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x)) + x
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# MBConv
class SqueezeExcitation(nn.Module):
def __init__(self, dim, shrinkage_rate = 0.25):
super().__init__()
hidden_dim = int(dim * shrinkage_rate)
self.gate = nn.Sequential(
Reduce('b c h w -> b c', 'mean'),
nn.Linear(dim, hidden_dim, bias = False),
nn.SiLU(),
nn.Linear(hidden_dim, dim, bias = False),
nn.Sigmoid(),
Rearrange('b c -> b c 1 1')
)
def forward(self, x):
return x * self.gate(x)
class MBConvResidual(nn.Module):
def __init__(self, fn, dropout = 0.):
super().__init__()
self.fn = fn
self.dropsample = Dropsample(dropout)
def forward(self, x):
out = self.fn(x)
out = self.dropsample(out)
return out + x
class Dropsample(nn.Module):
def __init__(self, prob = 0):
super().__init__()
self.prob = prob
def forward(self, x):
device = x.device
if self.prob == 0. or (not self.training):
return x
keep_mask = torch.FloatTensor((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
return x * keep_mask / (1 - self.prob)
def MBConv(
dim_in,
dim_out,
*,
downsample,
expansion_rate = 4,
shrinkage_rate = 0.25,
dropout = 0.
):
hidden_dim = int(expansion_rate * dim_out)
stride = 2 if downsample else 1
net = nn.Sequential(
nn.Conv2d(dim_in, dim_out, 1),
nn.BatchNorm2d(dim_out),
nn.SiLU(),
nn.Conv2d(dim_out, dim_out, 3, stride = stride, padding = 1, groups = dim_out),
SqueezeExcitation(dim_out, shrinkage_rate = shrinkage_rate),
nn.Conv2d(dim_out, dim_out, 1),
nn.BatchNorm2d(dim_out)
)
if dim_in == dim_out and not downsample:
net = MBConvResidual(net, dropout = dropout)
return net
# attention related classes
class Attention(nn.Module):
def __init__(
self,
dim,
dim_head = 32,
dropout = 0.,
window_size = 7
):
super().__init__()
assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
self.heads = dim // dim_head
self.scale = dim_head ** -0.5
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
)
self.to_out = nn.Sequential(
nn.Linear(dim, dim, bias = False),
nn.Dropout(dropout)
)
# relative positional bias
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
def forward(self, x):
batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
# flatten
x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
# project for queries, keys, values
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
# split heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d ) -> b h n d', h = h), (q, k, v))
# scale
q = q * self.scale
# sim
sim = einsum('b h i d, b h j d -> b h i j', q, k)
# add positional bias
bias = self.rel_pos_bias(self.rel_pos_indices)
sim = sim + rearrange(bias, 'i j h -> h i j')
# attention
attn = self.attend(sim)
# aggregate
out = einsum('b h i j, b h j d -> b h i d', attn, v)
# merge heads
out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)
# combine heads out
out = self.to_out(out)
return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)
class MaxViT(nn.Module):
def __init__(
self,
*,
num_classes,
dim,
depth,
dim_head = 32,
dim_conv_stem = None,
window_size = 7,
mbconv_expansion_rate = 4,
mbconv_shrinkage_rate = 0.25,
dropout = 0.1,
channels = 3
):
super().__init__()
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
# convolutional stem
dim_conv_stem = default(dim_conv_stem, dim)
self.conv_stem = nn.Sequential(
nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
)
# variables
num_stages = len(depth)
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
dims = (dim_conv_stem, *dims)
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
self.layers = nn.ModuleList([])
# shorthand for window size for efficient block - grid like attention
w = window_size
# iterate through stages
for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
for stage_ind in range(layer_depth):
is_first = stage_ind == 0
stage_dim_in = layer_dim_in if is_first else layer_dim
block = nn.Sequential(
MBConv(
stage_dim_in,
layer_dim,
downsample = is_first,
expansion_rate = mbconv_expansion_rate,
shrinkage_rate = mbconv_shrinkage_rate
),
Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
)
self.layers.append(block)
# mlp head out
self.mlp_head = nn.Sequential(
Reduce('b d h w -> b d', 'mean'),
nn.LayerNorm(dims[-1]),
nn.Linear(dims[-1], num_classes)
)
def forward(self, x):
x = self.conv_stem(x)
for stage in self.layers:
x = stage(x)
return self.mlp_head(x)

252
vit_pytorch/mobile_vit.py Normal file
View File

@@ -0,0 +1,252 @@
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Reduce
# helpers
def conv_1x1_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.SiLU()
)
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
return nn.Sequential(
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.SiLU()
)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(
t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b p h n d -> b p n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
"""Transformer block described in ViT.
Paper: https://arxiv.org/abs/2010.11929
Based on: https://github.com/lucidrains/vit-pytorch
"""
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class MV2Block(nn.Module):
"""MV2 block described in MobileNetV2.
Paper: https://arxiv.org/pdf/1801.04381
Based on: https://github.com/tonylins/pytorch-mobilenet-v2
"""
def __init__(self, inp, oup, stride=1, expansion=4):
super().__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(inp * expansion)
self.use_res_connect = self.stride == 1 and inp == oup
if expansion == 1:
self.conv = nn.Sequential(
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
else:
self.conv = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
def forward(self, x):
out = self.conv(x)
if self.use_res_connect:
out = out + x
return out
class MobileViTBlock(nn.Module):
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
super().__init__()
self.ph, self.pw = patch_size
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
self.conv2 = conv_1x1_bn(channel, dim)
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
self.conv3 = conv_1x1_bn(dim, channel)
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
def forward(self, x):
y = x.clone()
# Local representations
x = self.conv1(x)
x = self.conv2(x)
# Global representations
_, _, h, w = x.shape
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
# Fusion
x = self.conv3(x)
x = torch.cat((x, y), 1)
x = self.conv4(x)
return x
class MobileViT(nn.Module):
"""MobileViT.
Paper: https://arxiv.org/abs/2110.02178
Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
"""
def __init__(
self,
image_size,
dims,
channels,
num_classes,
expansion=4,
kernel_size=3,
patch_size=(2, 2),
depths=(2, 4, 3)
):
super().__init__()
assert len(dims) == 3, 'dims must be a tuple of 3'
assert len(depths) == 3, 'depths must be a tuple of 3'
ih, iw = image_size
ph, pw = patch_size
assert ih % ph == 0 and iw % pw == 0
init_dim, *_, last_dim = channels
self.conv1 = conv_nxn_bn(3, init_dim, stride=2)
self.stem = nn.ModuleList([])
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.trunk = nn.ModuleList([])
self.trunk.append(nn.ModuleList([
MV2Block(channels[3], channels[4], 2, expansion),
MobileViTBlock(dims[0], depths[0], channels[5],
kernel_size, patch_size, int(dims[0] * 2))
]))
self.trunk.append(nn.ModuleList([
MV2Block(channels[5], channels[6], 2, expansion),
MobileViTBlock(dims[1], depths[1], channels[7],
kernel_size, patch_size, int(dims[1] * 4))
]))
self.trunk.append(nn.ModuleList([
MV2Block(channels[7], channels[8], 2, expansion),
MobileViTBlock(dims[2], depths[2], channels[9],
kernel_size, patch_size, int(dims[2] * 4))
]))
self.to_logits = nn.Sequential(
conv_1x1_bn(channels[-2], last_dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(channels[-1], num_classes, bias=False)
)
def forward(self, x):
x = self.conv1(x)
for conv in self.stem:
x = conv(x)
for conv, attn in self.trunk:
x = conv(x)
x = attn(x)
return self.to_logits(x)

View File

@@ -1,20 +1,20 @@
import math
from functools import reduce
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops import rearrange, repeat, reduce
# helpers
def exists(val):
return val is not None
def prob_mask_like(t, prob):
batch, seq_length, _ = t.shape
return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob
def get_mask_subset_with_prob(patched_input, prob):
batch, seq_len, _, device = *patched_input.shape, patched_input.device
max_masked = math.ceil(prob * seq_len)
@@ -31,43 +31,45 @@ def get_mask_subset_with_prob(patched_input, prob):
class MPPLoss(nn.Module):
def __init__(self, patch_size, channels, output_channel_bits,
max_pixel_val):
super(MPPLoss, self).__init__()
def __init__(
self,
patch_size,
channels,
output_channel_bits,
max_pixel_val,
mean,
std
):
super().__init__()
self.patch_size = patch_size
self.channels = channels
self.output_channel_bits = output_channel_bits
self.max_pixel_val = max_pixel_val
self.mean = torch.tensor(mean).view(-1, 1, 1) if mean else None
self.std = torch.tensor(std).view(-1, 1, 1) if std else None
def forward(self, predicted_patches, target, mask):
p, c, mpv, bits, device = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits, target.device
bin_size = mpv / (2 ** bits)
# un-normalize input
if exists(self.mean) and exists(self.std):
target = target * self.std + self.mean
# reshape target to patches
p = self.patch_size
target = rearrange(target,
"b c (h p1) (w p2) -> b (h w) c (p1 p2) ",
p1=p,
p2=p)
target = target.clamp(max = mpv) # clamp just in case
avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1 = p, p2 = p).contiguous()
avg_target = target.mean(dim=3)
bin_size = self.max_pixel_val / self.output_channel_bits
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size)
channel_bins = torch.arange(bin_size, mpv, bin_size, device = device)
discretized_target = torch.bucketize(avg_target, channel_bins)
discretized_target = F.one_hot(discretized_target,
self.output_channel_bits)
c, bi = self.channels, self.output_channel_bits
discretized_target = rearrange(discretized_target,
"b n c bi -> b n (c bi)",
c=c,
bi=bi)
bin_mask = 2**torch.arange(c * bi - 1, -1,
-1).to(discretized_target.device,
discretized_target.dtype)
target_label = torch.sum(bin_mask * discretized_target, -1)
bin_mask = (2 ** bits) ** torch.arange(0, c, device = device).long()
bin_mask = rearrange(bin_mask, 'c -> () () c')
predicted_patches = predicted_patches[mask]
target_label = target_label[mask]
loss = F.cross_entropy(predicted_patches, target_label)
target_label = torch.sum(bin_mask * discretized_target, dim = -1)
loss = F.cross_entropy(predicted_patches[mask], target_label[mask])
return loss
@@ -75,21 +77,24 @@ class MPPLoss(nn.Module):
class MPP(nn.Module):
def __init__(self,
transformer,
patch_size,
dim,
output_channel_bits=3,
channels=3,
max_pixel_val=1.0,
mask_prob=0.15,
replace_prob=0.5,
random_patch_prob=0.5):
def __init__(
self,
transformer,
patch_size,
dim,
output_channel_bits=3,
channels=3,
max_pixel_val=1.0,
mask_prob=0.15,
replace_prob=0.5,
random_patch_prob=0.5,
mean=None,
std=None
):
super().__init__()
self.transformer = transformer
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
max_pixel_val)
max_pixel_val, mean, std)
# output transformation
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
@@ -103,7 +108,7 @@ class MPP(nn.Module):
self.random_patch_prob = random_patch_prob
# token ids
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))
def forward(self, input, **kwargs):
transformer = self.transformer
@@ -127,8 +132,9 @@ class MPP(nn.Module):
random_patch_sampling_prob = self.random_patch_prob / (
1 - self.replace_prob)
random_patch_prob = prob_mask_like(input,
random_patch_sampling_prob)
bool_random_patch_prob = mask * random_patch_prob == True
random_patch_sampling_prob).to(mask.device)
bool_random_patch_prob = mask * (random_patch_prob == True)
random_patches = torch.randint(0,
input.shape[1],
(input.shape[0], input.shape[1]),
@@ -140,7 +146,7 @@ class MPP(nn.Module):
bool_random_patch_prob]
# [mask] input
replace_prob = prob_mask_like(input, self.replace_prob)
replace_prob = prob_mask_like(input, self.replace_prob).to(mask.device)
bool_mask_replace = (mask * replace_prob) == True
masked_input[bool_mask_replace] = self.mask_token

183
vit_pytorch/nest.py Normal file
View File

@@ -0,0 +1,183 @@
from functools import partial
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
# helpers
def cast_tuple(val, depth):
return val if isinstance(val, tuple) else ((val,) * depth)
# classes
class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mlp_mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mlp_mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(dim * mlp_mult, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dropout = 0.):
super().__init__()
dim_head = dim // heads
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
b, c, h, w, heads = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
return self.to_out(out)
def Aggregate(dim, dim_out):
return nn.Sequential(
nn.Conv2d(dim, dim_out, 3, padding = 1),
LayerNorm(dim_out),
nn.MaxPool2d(3, stride = 2, padding = 1)
)
class Transformer(nn.Module):
def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
self.pos_emb = nn.Parameter(torch.randn(seq_len))
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
]))
def forward(self, x):
*_, h, w = x.shape
pos_emb = self.pos_emb[:(h * w)]
pos_emb = rearrange(pos_emb, '(h w) -> () () h w', h = h, w = w)
x = x + pos_emb
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class NesT(nn.Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
heads,
num_hierarchies,
block_repeats,
mlp_mult = 4,
channels = 3,
dim_head = 64,
dropout = 0.
):
super().__init__()
assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
fmap_size = image_size // patch_size
blocks = 2 ** (num_hierarchies - 1)
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy
hierarchies = list(reversed(range(num_hierarchies)))
mults = [2 ** i for i in reversed(hierarchies)]
layer_heads = list(map(lambda t: t * heads, mults))
layer_dims = list(map(lambda t: t * dim, mults))
last_dim = layer_dims[-1]
layer_dims = [*layer_dims, layer_dims[-1]]
dim_pairs = zip(layer_dims[:-1], layer_dims[1:])
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
nn.Conv2d(patch_dim, layer_dims[0], 1),
)
block_repeats = cast_tuple(block_repeats, num_hierarchies)
self.layers = nn.ModuleList([])
for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats):
is_last = level == 0
depth = block_repeat
self.layers.append(nn.ModuleList([
Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout),
Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
]))
self.mlp_head = nn.Sequential(
LayerNorm(last_dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(last_dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, c, h, w = x.shape
num_hierarchies = len(self.layers)
for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
block_size = 2 ** level
x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
x = transformer(x)
x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
x = aggregate(x)
return self.mlp_head(x)

140
vit_pytorch/parallel_vit.py Normal file
View File

@@ -0,0 +1,140 @@
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class Parallel(nn.Module):
def __init__(self, *fns):
super().__init__()
self.fns = nn.ModuleList(fns)
def forward(self, x):
return sum([fn(x) for fn in self.fns])
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_parallel_branches = 2, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
attn_block = lambda: PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
ff_block = lambda: PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
for _ in range(depth):
self.layers.append(nn.ModuleList([
Parallel(*[attn_block() for _ in range(num_parallel_branches)]),
Parallel(*[ff_block() for _ in range(num_parallel_branches)]),
]))
def forward(self, x):
for attns, ffs in self.layers:
x = attns(x) + x
x = ffs(x) + x
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', num_parallel_branches = 2, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_parallel_branches, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)

View File

@@ -48,6 +48,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
@@ -63,6 +64,7 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
@@ -89,8 +91,8 @@ class DepthWiseConv2d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
nn.Conv2d(dim_out, dim_out, kernel_size = 1, bias = bias)
)
def forward(self, x):
return self.net(x)
@@ -129,14 +131,15 @@ class PiT(nn.Module):
mlp_dim,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.
emb_dropout = 0.,
channels = 3
):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing'
heads = cast_tuple(heads, len(depth))
patch_dim = 3 * patch_size ** 2
patch_dim = channels * patch_size ** 2
self.to_patch_embedding = nn.Sequential(
nn.Unfold(kernel_size = patch_size, stride = patch_size // 2),
@@ -162,8 +165,9 @@ class PiT(nn.Module):
layers.append(Pool(dim))
dim *= 2
self.layers = nn.Sequential(
*layers,
self.layers = nn.Sequential(*layers)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
@@ -174,7 +178,9 @@ class PiT(nn.Module):
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :n+1]
x = self.dropout(x)
return self.layers(x)
x = self.layers(x)
return self.mlp_head(x[:, 0])

View File

@@ -8,7 +8,7 @@ def find_modules(nn_module, type):
return [module for module in nn_module.modules() if isinstance(module, type)]
class Recorder(nn.Module):
def __init__(self, vit):
def __init__(self, vit, device = None):
super().__init__()
self.vit = vit
@@ -17,6 +17,7 @@ class Recorder(nn.Module):
self.hooks = []
self.hook_registered = False
self.ejected = False
self.device = device
def _hook(self, _, input, output):
self.recordings.append(output.clone().detach())
@@ -45,10 +46,14 @@ class Recorder(nn.Module):
def forward(self, img):
assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
self.clear()
if not self.hook_registered:
self._register_hook()
pred = self.vit(img)
attns = torch.stack(self.recordings, dim = 1)
# move all recordings to one device before stacking
target_device = self.device if self.device is not None else img.device
recordings = tuple(map(lambda t: t.to(target_device), self.recordings))
attns = torch.stack(recordings, dim = 1) if len(recordings) > 0 else None
return pred, attns

269
vit_pytorch/regionvit.py Normal file
View File

@@ -0,0 +1,269 @@
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
import torch.nn.functional as F
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def divisible_by(val, d):
return (val % d) == 0
# helper classes
class Downsample(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)
def forward(self, x):
return self.conv(x)
class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
def forward(self, x):
return self.proj(x) + x
# transformer classes
def FeedForward(dim, mult = 4, dropout = 0.):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim, 1)
)
class Attention(nn.Module):
def __init__(
self,
dim,
heads = 4,
dim_head = 32,
dropout = 0.
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.norm = nn.LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, rel_pos_bias = None):
h = self.heads
# prenorm
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
# split heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
# add relative positional bias for local tokens
if exists(rel_pos_bias):
sim = sim + rel_pos_bias
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# merge heads
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class R2LTransformer(nn.Module):
def __init__(
self,
dim,
*,
window_size,
depth = 4,
heads = 4,
dim_head = 32,
attn_dropout = 0.,
ff_dropout = 0.,
):
super().__init__()
self.layers = nn.ModuleList([])
self.window_size = window_size
rel_positions = 2 * window_size - 1
self.local_rel_pos_bias = nn.Embedding(rel_positions ** 2, heads)
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
FeedForward(dim, dropout = ff_dropout)
]))
def forward(self, local_tokens, region_tokens):
device = local_tokens.device
lh, lw = local_tokens.shape[-2:]
rh, rw = region_tokens.shape[-2:]
window_size_h, window_size_w = lh // rh, lw // rw
local_tokens = rearrange(local_tokens, 'b c h w -> b (h w) c')
region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')
# calculate local relative positional bias
h_range = torch.arange(window_size_h, device = device)
w_range = torch.arange(window_size_w, device = device)
grid_x, grid_y = torch.meshgrid(h_range, w_range, indexing = 'ij')
grid = torch.stack((grid_x, grid_y))
grid = rearrange(grid, 'c h w -> c (h w)')
grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)
bias_indices = (grid * torch.tensor([1, self.window_size * 2 - 1], device = device)[:, None, None]).sum(dim = 0)
rel_pos_bias = self.local_rel_pos_bias(bias_indices)
rel_pos_bias = rearrange(rel_pos_bias, 'i j h -> () h i j')
rel_pos_bias = F.pad(rel_pos_bias, (1, 0, 1, 0), value = 0)
# go through r2l transformer layers
for attn, ff in self.layers:
region_tokens = attn(region_tokens) + region_tokens
# concat region tokens to local tokens
local_tokens = rearrange(local_tokens, 'b (h w) d -> b h w d', h = lh)
local_tokens = rearrange(local_tokens, 'b (h p1) (w p2) d -> (b h w) (p1 p2) d', p1 = window_size_h, p2 = window_size_w)
region_tokens = rearrange(region_tokens, 'b n d -> (b n) () d')
# do self attention on local tokens, along with its regional token
region_and_local_tokens = torch.cat((region_tokens, local_tokens), dim = 1)
region_and_local_tokens = attn(region_and_local_tokens, rel_pos_bias = rel_pos_bias) + region_and_local_tokens
# feedforward
region_and_local_tokens = ff(region_and_local_tokens) + region_and_local_tokens
# split back local and regional tokens
region_tokens, local_tokens = region_and_local_tokens[:, :1], region_and_local_tokens[:, 1:]
local_tokens = rearrange(local_tokens, '(b h w) (p1 p2) d -> b (h p1 w p2) d', h = lh // window_size_h, w = lw // window_size_w, p1 = window_size_h)
region_tokens = rearrange(region_tokens, '(b n) () d -> b n d', n = rh * rw)
local_tokens = rearrange(local_tokens, 'b (h w) c -> b c h w', h = lh, w = lw)
region_tokens = rearrange(region_tokens, 'b (h w) c -> b c h w', h = rh, w = rw)
return local_tokens, region_tokens
# classes
class RegionViT(nn.Module):
def __init__(
self,
*,
dim = (64, 128, 256, 512),
depth = (2, 2, 8, 2),
window_size = 7,
num_classes = 1000,
tokenize_local_3_conv = False,
local_patch_size = 4,
use_peg = False,
attn_dropout = 0.,
ff_dropout = 0.,
channels = 3,
):
super().__init__()
dim = cast_tuple(dim, 4)
depth = cast_tuple(depth, 4)
assert len(dim) == 4, 'dim needs to be a single value or a tuple of length 4'
assert len(depth) == 4, 'depth needs to be a single value or a tuple of length 4'
self.local_patch_size = local_patch_size
region_patch_size = local_patch_size * window_size
self.region_patch_size = local_patch_size * window_size
init_dim, *_, last_dim = dim
# local and region encoders
if tokenize_local_3_conv:
self.local_encoder = nn.Sequential(
nn.Conv2d(3, init_dim, 3, 2, 1),
nn.LayerNorm(init_dim),
nn.GELU(),
nn.Conv2d(init_dim, init_dim, 3, 2, 1),
nn.LayerNorm(init_dim),
nn.GELU(),
nn.Conv2d(init_dim, init_dim, 3, 1, 1)
)
else:
self.local_encoder = nn.Conv2d(3, init_dim, 8, 4, 3)
self.region_encoder = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = region_patch_size, p2 = region_patch_size),
nn.Conv2d((region_patch_size ** 2) * channels, init_dim, 1)
)
# layers
current_dim = init_dim
self.layers = nn.ModuleList([])
for ind, dim, num_layers in zip(range(4), dim, depth):
not_first = ind != 0
need_downsample = not_first
need_peg = not_first and use_peg
self.layers.append(nn.ModuleList([
Downsample(current_dim, dim) if need_downsample else nn.Identity(),
PEG(dim) if need_peg else nn.Identity(),
R2LTransformer(dim, depth = num_layers, window_size = window_size, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
]))
current_dim = dim
# final logits
self.to_logits = nn.Sequential(
Reduce('b c h w -> b c', 'mean'),
nn.LayerNorm(last_dim),
nn.Linear(last_dim, num_classes)
)
def forward(self, x):
*_, h, w = x.shape
assert divisible_by(h, self.region_patch_size) and divisible_by(w, self.region_patch_size), 'height and width must be divisible by region patch size'
assert divisible_by(h, self.local_patch_size) and divisible_by(w, self.local_patch_size), 'height and width must be divisible by local patch size'
local_tokens = self.local_encoder(x)
region_tokens = self.region_encoder(x)
for down, peg, transformer in self.layers:
local_tokens, region_tokens = down(local_tokens), down(region_tokens)
local_tokens = peg(local_tokens)
local_tokens, region_tokens = transformer(local_tokens, region_tokens)
return self.to_logits(region_tokens)

View File

@@ -19,7 +19,7 @@ class AxialRotaryEmbedding(nn.Module):
def __init__(self, dim, max_freq = 10):
super().__init__()
self.dim = dim
scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
self.register_buffer('scales', scales)
def forward(self, x):
@@ -67,12 +67,14 @@ class SpatialConv(nn.Module):
def __init__(self, dim_in, dim_out, kernel, bias = False):
super().__init__()
self.conv = DepthWiseConv2d(dim_in, dim_out, kernel, padding = kernel // 2, bias = False)
self.cls_proj = nn.Linear(dim_in, dim_out) if dim_in != dim_out else nn.Identity()
def forward(self, x, fmap_dims):
cls_token, x = x[:, :1], x[:, 1:]
x = rearrange(x, 'b (h w) d -> b d h w', **fmap_dims)
x = self.conv(x)
x = rearrange(x, 'b d h w -> b (h w) d')
cls_token = self.cls_proj(cls_token)
return torch.cat((cls_token, x), dim = 1)
class GEGLU(nn.Module):
@@ -81,11 +83,11 @@ class GEGLU(nn.Module):
return F.gelu(gates) * x
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
def __init__(self, dim, hidden_dim, dropout = 0., use_glu = True):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim * 2),
GEGLU(),
nn.Linear(dim, hidden_dim * 2 if use_glu else hidden_dim),
GEGLU() if use_glu else nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
@@ -94,16 +96,19 @@ class FeedForward(nn.Module):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., conv_query_kernel = 5):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = True, use_ds_conv = True, conv_query_kernel = 5):
super().__init__()
inner_dim = dim_head * heads
self.use_rotary = use_rotary
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False)
self.use_ds_conv = use_ds_conv
self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False) if use_ds_conv else nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -115,39 +120,50 @@ class Attention(nn.Module):
def forward(self, x, pos_emb, fmap_dims):
b, n, _, h = *x.shape, self.heads
q = self.to_q(x, fmap_dims = fmap_dims)
to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {}
q = self.to_q(x, **to_q_kwargs)
qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
# apply 2d rotary embeddings to queries and keys, excluding CLS tokens
if self.use_rotary:
# apply 2d rotary embeddings to queries and keys, excluding CLS tokens
sin, cos = pos_emb
(q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))
q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
sin, cos = pos_emb
dim_rotary = sin.shape[-1]
# concat back the CLS tokens
(q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))
q = torch.cat((q_cls, q), dim = 1)
k = torch.cat((k_cls, k), dim = 1)
# handle the case where rotary dimension < head dimension
(q, q_pass), (k, k_pass) = map(lambda t: (t[..., :dim_rotary], t[..., dim_rotary:]), (q, k))
q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
# concat back the CLS tokens
q = torch.cat((q_cls, q), dim = 1)
k = torch.cat((k_cls, k), dim = 1)
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, image_size, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
super().__init__()
self.layers = nn.ModuleList([])
self.pos_emb = AxialRotaryEmbedding(dim_head)
self.pos_emb = AxialRotaryEmbedding(dim_head, max_freq = image_size)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu))
]))
def forward(self, x, fmap_dims):
pos_emb = self.pos_emb(x[:, 1:])
@@ -160,7 +176,7 @@ class Transformer(nn.Module):
# Rotary Vision Transformer
class RvT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
@@ -173,7 +189,7 @@ class RvT(nn.Module):
)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, image_size, dropout, use_rotary, use_ds_conv, use_glu)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
@@ -192,4 +208,4 @@ class RvT(nn.Module):
fmap_dims = {'h': h // p, 'w': w // p}
x = self.transformer(x, fmap_dims = fmap_dims)
return self.mlp_head(x)
return self.mlp_head(x[:, 0])

306
vit_pytorch/scalable_vit.py Normal file
View File

@@ -0,0 +1,306 @@
from functools import partial
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
# helper classes
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanLayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
class Downsample(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)
def forward(self, x):
return self.conv(x)
class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
def forward(self, x):
return self.proj(x) + x
# feedforward
class FeedForward(nn.Module):
def __init__(self, dim, expansion_factor = 4, dropout = 0.):
super().__init__()
inner_dim = dim * expansion_factor
self.net = nn.Sequential(
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# attention
class ScalableSelfAttention(nn.Module):
def __init__(
self,
dim,
heads = 8,
dim_key = 32,
dim_value = 32,
dropout = 0.,
reduction_factor = 1
):
super().__init__()
self.heads = heads
self.scale = dim_key ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(dim_value * heads, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
height, width, heads = *x.shape[-2:], self.heads
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
# split out heads
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
# similarity
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# attention
attn = self.attend(dots)
attn = self.dropout(attn)
# aggregate values
out = torch.matmul(attn, v)
# merge back heads
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = height, y = width)
return self.to_out(out)
class InteractiveWindowedSelfAttention(nn.Module):
def __init__(
self,
dim,
window_size,
heads = 8,
dim_key = 32,
dim_value = 32,
dropout = 0.
):
super().__init__()
self.heads = heads
self.scale = dim_key ** -0.5
self.window_size = window_size
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_v = nn.Conv2d(dim, dim_value * heads, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(dim_value * heads, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size
wsz_h, wsz_w = default(wsz, height), default(wsz, width)
assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
# get output of LIM
local_out = self.local_interactive_module(v)
# divide into window (and split out heads) for efficient self attention
q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz_h, w2 = wsz_w), (q, k, v))
# similarity
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# attention
attn = self.attend(dots)
attn = self.dropout(attn)
# aggregate values
out = torch.matmul(attn, v)
# reshape the windows back to full feature map (and merge heads)
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w)
# add LIM output
out = out + local_out
return self.to_out(out)
class Transformer(nn.Module):
def __init__(
self,
dim,
depth,
heads = 8,
ff_expansion_factor = 4,
dropout = 0.,
ssa_dim_key = 32,
ssa_dim_value = 32,
ssa_reduction_factor = 1,
iwsa_dim_key = 32,
iwsa_dim_value = 32,
iwsa_window_size = None,
norm_output = True
):
super().__init__()
self.layers = nn.ModuleList([])
for ind in range(depth):
is_first = ind == 0
self.layers.append(nn.ModuleList([
PreNorm(dim, ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout)),
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
PEG(dim) if is_first else None,
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
PreNorm(dim, InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout))
]))
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for ssa, ff1, peg, iwsa, ff2 in self.layers:
x = ssa(x) + x
x = ff1(x) + x
if exists(peg):
x = peg(x)
x = iwsa(x) + x
x = ff2(x) + x
return self.norm(x)
class ScalableViT(nn.Module):
def __init__(
self,
*,
num_classes,
dim,
depth,
heads,
reduction_factor,
window_size = None,
iwsa_dim_key = 32,
iwsa_dim_value = 32,
ssa_dim_key = 32,
ssa_dim_value = 32,
ff_expansion_factor = 4,
channels = 3,
dropout = 0.
):
super().__init__()
self.to_patches = nn.Conv2d(channels, dim, 7, stride = 4, padding = 3)
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
num_stages = len(depth)
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
hyperparams_per_stage = [
heads,
ssa_dim_key,
ssa_dim_value,
reduction_factor,
iwsa_dim_key,
iwsa_dim_value,
window_size,
]
hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))
self.layers = nn.ModuleList([])
for ind, (layer_dim, layer_depth, layer_heads, layer_ssa_dim_key, layer_ssa_dim_value, layer_ssa_reduction_factor, layer_iwsa_dim_key, layer_iwsa_dim_value, layer_window_size) in enumerate(zip(dims, depth, *hyperparams_per_stage)):
is_last = ind == (num_stages - 1)
self.layers.append(nn.ModuleList([
Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_expansion_factor = ff_expansion_factor, dropout = dropout, ssa_dim_key = layer_ssa_dim_key, ssa_dim_value = layer_ssa_dim_value, ssa_reduction_factor = layer_ssa_reduction_factor, iwsa_dim_key = layer_iwsa_dim_key, iwsa_dim_value = layer_iwsa_dim_value, iwsa_window_size = layer_window_size, norm_output = not is_last),
Downsample(layer_dim, layer_dim * 2) if not is_last else None
]))
self.mlp_head = nn.Sequential(
Reduce('b d h w -> b d', 'mean'),
nn.LayerNorm(dims[-1]),
nn.Linear(dims[-1], num_classes)
)
def forward(self, img):
x = self.to_patches(img)
for transformer, downsample in self.layers:
x = transformer(x)
if exists(downsample):
x = downsample(x)
return self.mlp_head(x)

294
vit_pytorch/sep_vit.py Normal file
View File

@@ -0,0 +1,294 @@
from functools import partial
import torch
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# helpers
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
# helper classes
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanLayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
class OverlappingPatchEmbed(nn.Module):
def __init__(self, dim_in, dim_out, stride = 2):
super().__init__()
kernel_size = stride * 2 - 1
padding = kernel_size // 2
self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)
def forward(self, x):
return self.conv(x)
class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
def forward(self, x):
return self.proj(x) + x
# feedforward
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# attention
class DSSA(nn.Module):
def __init__(
self,
dim,
heads = 8,
dim_head = 32,
dropout = 0.,
window_size = 7
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
self.window_size = window_size
inner_dim = dim_head * heads
self.attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
)
self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
# window tokens
self.window_tokens = nn.Parameter(torch.randn(dim))
# prenorm and non-linearity for window tokens
# then projection to queries and keys for window tokens
self.window_tokens_to_qk = nn.Sequential(
nn.LayerNorm(dim_head),
nn.GELU(),
Rearrange('b h n c -> b (h c) n'),
nn.Conv1d(inner_dim, inner_dim * 2, 1),
Rearrange('b (h c) n -> b h n c', h = heads),
)
# window attention
self.window_attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
"""
einstein notation
b - batch
c - channels
w1 - window size (height)
w2 - also window size (width)
i - sequence dimension (source)
j - sequence dimension (target dimension to be reduced)
h - heads
x - height of feature map divided by window size
y - width of feature map divided by window size
"""
batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
num_windows = (height // wsz) * (width // wsz)
# fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
# add windowing tokens
w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0])
x = torch.cat((w, x), dim = -1)
# project for queries, keys, value
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
# split out heads
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
# scale
q = q * self.scale
# similarity
dots = einsum('b h i d, b h j d -> b h i j', q, k)
# attention
attn = self.attend(dots)
# aggregate values
out = torch.matmul(attn, v)
# split out windowed tokens
window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]
# early return if there is only 1 window
if num_windows == 1:
fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
return self.to_out(fmap)
# carry out the pointwise attention, the main novelty in the paper
window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz)
windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz)
# windowed queries and keys (preceded by prenorm activation)
w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)
# scale
w_q = w_q * self.scale
# similarities
w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)
w_attn = self.window_attend(w_dots)
# aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one i haven't seen before)
aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)
# fold back the windows and then combine heads for aggregation
fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
return self.to_out(fmap)
class Transformer(nn.Module):
def __init__(
self,
dim,
depth,
dim_head = 32,
heads = 8,
ff_mult = 4,
dropout = 0.,
norm_output = True
):
super().__init__()
self.layers = nn.ModuleList([])
for ind in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = dropout)),
]))
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class SepViT(nn.Module):
def __init__(
self,
*,
num_classes,
dim,
depth,
heads,
window_size = 7,
dim_head = 32,
ff_mult = 4,
channels = 3,
dropout = 0.
):
super().__init__()
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
num_stages = len(depth)
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
dims = (channels, *dims)
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
strides = (4, *((2,) * (num_stages - 1)))
hyperparams_per_stage = [heads, window_size]
hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))
self.layers = nn.ModuleList([])
for ind, ((layer_dim_in, layer_dim), layer_depth, layer_stride, layer_heads, layer_window_size) in enumerate(zip(dim_pairs, depth, strides, *hyperparams_per_stage)):
is_last = ind == (num_stages - 1)
self.layers.append(nn.ModuleList([
OverlappingPatchEmbed(layer_dim_in, layer_dim, stride = layer_stride),
PEG(layer_dim),
Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_mult = ff_mult, dropout = dropout, norm_output = not is_last),
]))
self.mlp_head = nn.Sequential(
Reduce('b d h w -> b d', 'mean'),
nn.LayerNorm(dims[-1]),
nn.Linear(dims[-1], num_classes)
)
def forward(self, x):
for ope, peg, transformer in self.layers:
x = ope(x)
x = peg(x)
x = transformer(x)
return self.mlp_head(x)

84
vit_pytorch/simmim.py Normal file
View File

@@ -0,0 +1,84 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
class SimMIM(nn.Module):
def __init__(
self,
*,
encoder,
masking_ratio = 0.5
):
super().__init__()
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
self.masking_ratio = masking_ratio
# extract some hyperparameters and functions from encoder (vision transformer to be trained)
self.encoder = encoder
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
# simple linear head
self.mask_token = nn.Parameter(torch.randn(encoder_dim))
self.to_pixels = nn.Linear(encoder_dim, pixel_values_per_patch)
def forward(self, img):
device = img.device
# get patches
patches = self.to_patch(img)
batch, num_patches, *_ = patches.shape
# for indexing purposes
batch_range = torch.arange(batch, device = device)[:, None]
# get positions
pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)]
# patch to encoder tokens and add positions
tokens = self.patch_to_emb(patches)
tokens = tokens + pos_emb
# prepare mask tokens
mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_patches)
mask_tokens = mask_tokens + pos_emb
# calculate of patches needed to be masked, and get positions (indices) to be masked
num_masked = int(self.masking_ratio * num_patches)
masked_indices = torch.rand(batch, num_patches, device = device).topk(k = num_masked, dim = -1).indices
masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool()
# mask tokens
tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)
# attend with vision transformer
encoded = self.encoder.transformer(tokens)
# get the masked tokens
encoded_mask_tokens = encoded[batch_range, masked_indices]
# small linear projection for predicted pixel values
pred_pixel_values = self.to_pixels(encoded_mask_tokens)
# get the masked patches for the final reconstruction loss
masked_patches = patches[batch_range, masked_indices]
# calculate reconstruction loss
recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked
return recon_loss

120
vit_pytorch/simple_vit.py Normal file
View File

@@ -0,0 +1,120 @@
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(img, dim, temperature = 10000, dtype = torch.float32):
_, h, w, _, device, dtype = *img.shape, img.device, img.dtype
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1. / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.dim = dim
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
*_, h, w, dtype = *img.shape, img.dtype
x = self.to_patch_embedding(img)
pe = posemb_sincos_2d(x, self.dim)
x = rearrange(x, 'b ... d -> b (...) d') + pe
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)

View File

@@ -35,13 +35,14 @@ class T2TViT(nn.Module):
for i, (kernel_size, stride) in enumerate(t2t_layers):
layer_dim *= kernel_size ** 2
is_first = i == 0
is_last = i == (len(t2t_layers) - 1)
output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)
layers.extend([
RearrangeImage() if not is_first else nn.Identity(),
nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
Rearrange('b c n -> b n c'),
Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout),
Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout) if not is_last else nn.Identity(),
])
layers.append(nn.Linear(layer_dim, dim))
@@ -71,7 +72,7 @@ class T2TViT(nn.Module):
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :n+1]
x = self.dropout(x)
x = self.transformer(x)

232
vit_pytorch/twins_svt.py Normal file
View File

@@ -0,0 +1,232 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helper methods
def group_dict_by_key(cond, d):
return_val = [dict(), dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def group_by_key_prefix_and_remove_prefix(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# classes
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(dim * mult, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class PatchEmbedding(nn.Module):
def __init__(self, *, dim, dim_out, patch_size):
super().__init__()
self.dim = dim
self.dim_out = dim_out
self.patch_size = patch_size
self.proj = nn.Conv2d(patch_size ** 2 * dim, dim_out, 1)
def forward(self, fmap):
p = self.patch_size
fmap = rearrange(fmap, 'b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = p, p2 = p)
return self.proj(fmap)
class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = Residual(nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1))
def forward(self, x):
return self.proj(x)
class LocalAttention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., patch_size = 7):
super().__init__()
inner_dim = dim_head * heads
self.patch_size = patch_size
self.heads = heads
self.scale = dim_head ** -0.5
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, fmap):
shape, p = fmap.shape, self.patch_size
b, n, x, y, h = *shape, self.heads
x, y = map(lambda t: t // p, (x, y))
fmap = rearrange(fmap, 'b c (x p1) (y p2) -> (b x y) c p1 p2', p1 = p, p2 = p)
q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
q, k, v = map(lambda t: rearrange(t, 'b (h d) p1 p2 -> (b h) (p1 p2) d', h = h), (q, k, v))
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = dots.softmax(dim = - 1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b x y h) (p1 p2) d -> b (h d) (x p1) (y p2)', h = h, x = x, y = y, p1 = p, p2 = p)
return self.to_out(out)
class GlobalAttention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., k = 7):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)
self.dropout = nn.Dropout(dropout)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
shape = x.shape
b, n, _, y, h = *shape, self.heads
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = dots.softmax(dim = -1)
attn = self.dropout(attn)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads = 8, dim_head = 64, mlp_mult = 4, local_patch_size = 7, global_k = 7, dropout = 0., has_local = True):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size))) if has_local else nn.Identity(),
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))) if has_local else nn.Identity(),
Residual(PreNorm(dim, GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k))),
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout)))
]))
def forward(self, x):
for local_attn, ff1, global_attn, ff2 in self.layers:
x = local_attn(x)
x = ff1(x)
x = global_attn(x)
x = ff2(x)
return x
class TwinsSVT(nn.Module):
def __init__(
self,
*,
num_classes,
s1_emb_dim = 64,
s1_patch_size = 4,
s1_local_patch_size = 7,
s1_global_k = 7,
s1_depth = 1,
s2_emb_dim = 128,
s2_patch_size = 2,
s2_local_patch_size = 7,
s2_global_k = 7,
s2_depth = 1,
s3_emb_dim = 256,
s3_patch_size = 2,
s3_local_patch_size = 7,
s3_global_k = 7,
s3_depth = 5,
s4_emb_dim = 512,
s4_patch_size = 2,
s4_local_patch_size = 7,
s4_global_k = 7,
s4_depth = 4,
peg_kernel_size = 3,
dropout = 0.
):
super().__init__()
kwargs = dict(locals())
dim = 3
layers = []
for prefix in ('s1', 's2', 's3', 's4'):
config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
is_last = prefix == 's4'
dim_next = config['emb_dim']
layers.append(nn.Sequential(
PatchEmbedding(dim = dim, dim_out = dim_next, patch_size = config['patch_size']),
Transformer(dim = dim_next, depth = 1, local_patch_size = config['local_patch_size'], global_k = config['global_k'], dropout = dropout, has_local = not is_last),
PEG(dim = dim_next, kernel_size = peg_kernel_size),
Transformer(dim = dim_next, depth = config['depth'], local_patch_size = config['local_patch_size'], global_k = config['global_k'], dropout = dropout, has_local = not is_last)
))
dim = dim_next
self.layers = nn.Sequential(
*layers,
nn.AdaptiveAvgPool2d(1),
Rearrange('... () () -> ...'),
nn.Linear(dim, num_classes)
)
def forward(self, x):
return self.layers(x)

View File

@@ -1,10 +1,16 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
@@ -36,6 +42,8 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
@@ -44,15 +52,15 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -74,13 +82,17 @@ class Transformer(nn.Module):
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
@@ -102,7 +114,7 @@ class ViT(nn.Module):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)

View File

@@ -0,0 +1,145 @@
from math import sqrt
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class LSA(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.temperature.exp()
mask = torch.eye(dots.shape[-1], device = dots.device, dtype = torch.bool)
mask_value = -torch.finfo(dots.dtype).max
dots = dots.masked_fill(mask, mask_value)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class SPT(nn.Module):
def __init__(self, *, dim, patch_size, channels = 3):
super().__init__()
patch_dim = patch_size * patch_size * 5 * channels
self.to_patch_tokens = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim)
)
def forward(self, x):
shifts = ((1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1))
shifted_x = list(map(lambda shift: F.pad(x, shift), shifts))
x_with_shifts = torch.cat((x, *shifted_x), dim = 1)
return self.to_patch_tokens(x_with_shifts)
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = SPT(dim = dim, patch_size = patch_size, channels = channels)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)

View File

@@ -0,0 +1,147 @@
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# helpers
def exists(val):
return val is not None
def default(val ,d):
return val if exists(val) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# patch merger class
class PatchMerger(nn.Module):
def __init__(self, dim, num_tokens_out):
super().__init__()
self.scale = dim ** -0.5
self.norm = nn.LayerNorm(dim)
self.queries = nn.Parameter(torch.randn(num_tokens_out, dim))
def forward(self, x):
x = self.norm(x)
sim = torch.matmul(self.queries, x.transpose(-1, -2)) * self.scale
attn = sim.softmax(dim = -1)
return torch.matmul(attn, x)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8):
super().__init__()
self.layers = nn.ModuleList([])
self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper
self.patch_merger = PatchMerger(dim = dim, num_tokens_out = patch_merge_num_tokens)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for index, (attn, ff) in enumerate(self.layers):
x = attn(x) + x
x = ff(x) + x
if index == self.patch_merge_layer_index:
x = self.patch_merger(x)
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, patch_merge_layer, patch_merge_num_tokens)
self.mlp_head = nn.Sequential(
Reduce('b n d -> b d', 'mean'),
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
x += self.pos_embedding[:, :n]
x = self.dropout(x)
x = self.transformer(x)
return self.mlp_head(x)