Compare commits

50 Commits

Author SHA1 Message Date
Phil Wang
6549522629 be able to accept non-square patches, thanks to @FilipAndersson245 2021-05-01 20:04:41 -07:00
Phil Wang
6a80a4ef89 update readme 2021-05-01 11:51:35 -07:00
Phil Wang
9f05587a7d 0.17.2 2021-04-30 06:44:59 -07:00
Phil Wang
65bb350e85 0.17.2 2021-04-30 06:44:54 -07:00
Phil Wang
fd4a7dfcf8 Merge pull request #102 from jon-tow/rvt-add-use-glu-flag
Add `use_glu` flag to `RvT`
2021-04-30 06:44:41 -07:00
Jonathan Tow
6f3a5fcf0b Add use_glu flag to RvT 2021-04-30 02:07:41 -04:00
Phil Wang
7807f24509 fix small bug 2021-04-29 15:39:41 -07:00
Phil Wang
a612327126 readme 2021-04-29 15:22:12 -07:00
Phil Wang
30a1335d31 release twins svt 2021-04-29 14:55:25 -07:00
Phil Wang
ab781f7ddb add Twins SVT (small) 2021-04-29 14:54:06 -07:00
Phil Wang
4f3dbd003f for PiT, project to increased dimensions on first grouped conv for depthwise-conv 2021-04-29 12:41:00 -07:00
Phil Wang
60b5687a79 cleanup rvt 2021-04-27 11:45:46 -07:00
Phil Wang
0df1505662 add zeroing of weight parameters of batchnorm in levit just before residual connection, noticed by @EelcoHoogendoorn 2021-04-27 08:41:16 -07:00
Phil Wang
3df6c31c61 fix norm issues in cvt 2021-04-27 08:36:17 -07:00
Phil Wang
54af220930 fix cvt 2021-04-26 20:37:51 -07:00
Phil Wang
bad4b94e7b fix all issues with rotary vision transformer 2021-04-25 12:09:32 -07:00
Phil Wang
fbced01fe7 cite 2021-04-20 18:36:54 -07:00
Phil Wang
e42e9876bc offer a way to turn off ds conv in rotary vision transformer for ablation 2021-04-20 10:12:03 -07:00
Phil Wang
566365978d add ability to turn off rotary, for ablation 2021-04-20 09:00:27 -07:00
Phil Wang
34f78294d3 fix pooling bugs across a few new archs 2021-04-19 22:36:23 -07:00
Phil Wang
4c29328363 fix frequency in rotary vision transformer 2021-04-15 16:06:32 -07:00
Phil Wang
27ac10c1f1 0.16.3 2021-04-14 16:53:05 -07:00
Phil Wang
fa216c45ea tweak 2021-04-14 16:52:53 -07:00
Phil Wang
1d8b7826bf update personal pet vit 2021-04-14 15:56:39 -07:00
Phil Wang
53b3af05f6 use convolution on query with padding to give the network absolute spatial awareness in addition to relative encoding from rotary embeddings 2021-04-14 15:56:02 -07:00
Phil Wang
6289619e3f 0.16.1 2021-04-14 08:05:08 -07:00
Phil Wang
b42fa7862e Merge pull request #91 from shabie/patch-1
Fix alpha coefficient multiplication in the loss
2021-04-14 08:04:50 -07:00
shabie
dc6622c05c Fix alpha coefficient multiplication in the loss 2021-04-14 11:36:43 +02:00
Phil Wang
30b37c4028 add LocalViT 2021-04-12 19:17:32 -07:00
Phil Wang
4497f1e90f add rotary vision transformer 2021-04-10 22:59:15 -07:00
Phil Wang
b50d3e1334 cleanup levit 2021-04-06 13:46:19 -07:00
Phil Wang
e075460937 stray print 2021-04-06 13:38:52 -07:00
Phil Wang
5e23e48e4d Merge pull request #88 from lucidrains/levit
fix images
2021-04-06 13:37:46 -07:00
Phil Wang
db04c0f319 fix images 2021-04-06 13:37:23 -07:00
Phil Wang
0f31ca79e3 Merge pull request #87 from lucidrains/levit
levit without pos emb
2021-04-06 13:36:26 -07:00
Phil Wang
2cb6b35030 complete levit 2021-04-06 13:36:11 -07:00
Phil Wang
2ec9161a98 levit without pos emb 2021-04-06 12:58:05 -07:00
Phil Wang
3a3038c702 add layer dropout for CaiT 2021-04-01 20:30:37 -07:00
Phil Wang
b1f1044c8e offer hard distillation as well 2021-04-01 16:56:14 -07:00
Phil Wang
deb96201d5 readme 2021-03-31 23:02:47 -07:00
Phil Wang
05b47cc070 make sure layerscale epsilon is a function of depth 2021-03-31 22:53:04 -07:00
Phil Wang
9ef8da4759 add CaiT, new vision transformer out of facebook AI, complete with layerscale, talking heads, and cls -> patch cross attention 2021-03-31 22:42:16 -07:00
Phil Wang
506fcf83a6 add documentation for three recent vision transformer follow-up papers 2021-03-31 09:22:15 -07:00
Phil Wang
6fb360a1ff add arxiv links for now, document in readme later 2021-03-30 22:26:44 -07:00
Phil Wang
9332b9e8c9 cite 2021-03-30 22:16:14 -07:00
Phil Wang
da950e6d2c add working PiT 2021-03-30 22:15:19 -07:00
Phil Wang
4b9a02d89c use depthwise conv for CvT projections 2021-03-30 18:18:35 -07:00
Phil Wang
518924eac5 add CvT 2021-03-30 14:42:39 -07:00
Phil Wang
e712003dfb add CrossViT 2021-03-30 00:53:27 -07:00
Phil Wang
d04ce06a30 make recorder work for t2t and deepvit 2021-03-29 18:16:34 -07:00
19 changed files with 1967 additions and 18 deletions

353
README.md

@@ -38,6 +38,7 @@ preds = v(img) # (1, 1000)
```
## Parameters
- `image_size`: int.
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
- `patch_size`: int.
@@ -93,7 +94,8 @@ distiller = DistillWrapper(
student = v,
teacher = teacher,
temperature = 3, # temperature of distillation
alpha = 0.5 # trade between main loss and distillation loss
alpha = 0.5, # trade between main loss and distillation loss
hard = False # whether to use soft or hard distillation
)
img = torch.randn(2, 3, 256, 256)
@@ -143,6 +145,37 @@ img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
```
## CaiT
<a href="https://arxiv.org/abs/2103.17239">This paper</a> also notes difficulty in training vision transformers at greater depths and proposes two solutions. First, it proposes per-channel multiplication of the output of each residual block. Second, it proposes to have the patches attend to one another, and only allow the CLS token to attend to the patches in the last few layers.
They also add <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a>, noting improvements.
You can use this scheme as follows:
```python
import torch
from vit_pytorch.cait import CaiT
v = CaiT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 12, # depth of transformer for patch to patch attention only
cls_depth = 2, # depth of cross attention of CLS tokens to patch
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1,
layer_dropout = 0.05 # randomly dropout 5% of the layers
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
```
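To make the per-channel multiplication concrete, here is a minimal sketch on plain Python lists (no PyTorch). The depth thresholds mirror the ones used in `vit_pytorch/cait.py`; the function names `layerscale_init_eps` and `layerscale_residual` are hypothetical and only for illustration.

```python
# Illustrative sketch only: hypothetical helper names, plain lists instead of tensors.

def layerscale_init_eps(depth):
    """Initial scale for the layer at (1-based) position `depth`."""
    if depth <= 18:
        return 0.1
    if depth <= 24:
        return 1e-5
    return 1e-6

def layerscale_residual(x, branch, scale):
    """Compute x + scale * branch(x), with one scale value per channel."""
    out = branch(x)
    return [xi + si * oi for xi, si, oi in zip(x, scale, out)]

dim = 4
scale = [layerscale_init_eps(depth = 1)] * dim  # a learnable parameter in the real module
x = [1.0, 2.0, 3.0, 4.0]

# stand-in for an attention or feedforward block: simply doubles its input
y = layerscale_residual(x, lambda t: [2.0 * v for v in t], scale)
print(y)
```

Deeper layers start with a much smaller scale, so each residual branch contributes almost nothing at initialization.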
## Token-to-Token ViT
<img src="./images/t2t.png" width="400px"></img>
@@ -164,7 +197,183 @@ v = T2TViT(
)
img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)
preds = v(img) # (1, 1000)
```
## Cross ViT
<img src="./images/cross_vit.png" width="400px"></img>
<a href="https://arxiv.org/abs/2103.14899">This paper</a> proposes to have two vision transformers processing the image at different scales, cross attending to one another every so often. They show improvements on top of the base vision transformer.
```python
import torch
from vit_pytorch.cross_vit import CrossViT
v = CrossViT(
image_size = 256,
num_classes = 1000,
depth = 4, # number of multi-scale encoding blocks
sm_dim = 192, # high res dimension
sm_patch_size = 16, # high res patch size (should be smaller than lg_patch_size)
sm_enc_depth = 2, # high res depth
sm_enc_heads = 8, # high res heads
sm_enc_mlp_dim = 2048, # high res feedforward dimension
lg_dim = 384, # low res dimension
lg_patch_size = 64, # low res patch size
lg_enc_depth = 3, # low res depth
lg_enc_heads = 8, # low res heads
lg_enc_mlp_dim = 2048, # low res feedforward dimensions
cross_attn_depth = 2, # cross attention rounds
cross_attn_heads = 8, # cross attention heads
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
pred = v(img) # (1, 1000)
```
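As a rough check of what the two branches see, the sketch below counts tokens per branch for the settings above. It assumes one CLS token per branch and square patches, as in `ImageEmbedder`; the helper name `num_tokens` is hypothetical.

```python
# Illustrative sketch only: `num_tokens` is a hypothetical helper, not part of vit_pytorch.

def num_tokens(image_size, patch_size):
    """Patch tokens for a square image, plus one CLS token."""
    assert image_size % patch_size == 0, 'image size must be divisible by the patch size'
    return (image_size // patch_size) ** 2 + 1

sm_tokens = num_tokens(256, 16)  # high res branch
lg_tokens = num_tokens(256, 64)  # low res branch

# during cross attention, a branch's CLS token attends to the other branch's
# patch tokens plus itself (kv_include_self = True)
sm_cls_keys = 1 + (lg_tokens - 1)
lg_cls_keys = 1 + (sm_tokens - 1)

print(sm_tokens, lg_tokens, sm_cls_keys, lg_cls_keys)
```

The cross attention is cheap because only the single CLS token of each branch acts as a query.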
## PiT
<img src="./images/pit.png" width="400px"></img>
<a href="https://arxiv.org/abs/2103.16302">This paper</a> proposes to downsample the tokens through a pooling procedure using depth-wise convolutions.
```python
import torch
from vit_pytorch.pit import PiT
v = PiT(
image_size = 224,
patch_size = 14,
dim = 256,
num_classes = 1000,
depth = (3, 3, 3), # list of depths, indicating the number of rounds of each stage before a downsample
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
# forward pass returns the class predictions
img = torch.randn(1, 3, 224, 224)
preds = v(img) # (1, 1000)
```
## LeViT
<img src="./images/levit.png" width="300px"></img>
<a href="https://arxiv.org/abs/2104.01136">This paper</a> proposes a number of changes, including (1) a convolutional embedding instead of patch-wise projection, (2) downsampling in stages, (3) an extra non-linearity in attention, (4) 2d relative positional biases instead of an initial absolute positional bias, and (5) batchnorm in place of layernorm.
```python
import torch
from vit_pytorch.levit import LeViT
levit = LeViT(
image_size = 224,
num_classes = 1000,
stages = 3, # number of stages
dim = (256, 384, 512), # dimensions at each stage
depth = 4, # transformer of depth 4 at each stage
heads = (4, 6, 8), # heads at each stage
mlp_mult = 2,
dropout = 0.1
)
img = torch.randn(1, 3, 224, 224)
levit(img) # (1, 1000)
```
## CvT
<img src="./images/cvt.png" width="400px"></img>
<a href="https://arxiv.org/abs/2103.15808">This paper</a> proposes mixing convolutions and attention. Specifically, convolutions are used to embed and downsample the image / feature map in three stages. Depthwise convolution is also used to project the queries, keys, and values for attention.
```python
import torch
from vit_pytorch.cvt import CvT
v = CvT(
num_classes = 1000,
s1_emb_dim = 64, # stage 1 - dimension
s1_emb_kernel = 7, # stage 1 - conv kernel
s1_emb_stride = 4, # stage 1 - conv stride
s1_proj_kernel = 3, # stage 1 - attention ds-conv kernel size
s1_kv_proj_stride = 2, # stage 1 - attention key / value projection stride
s1_heads = 1, # stage 1 - heads
s1_depth = 1, # stage 1 - depth
s1_mlp_mult = 4, # stage 1 - feedforward expansion factor
s2_emb_dim = 192, # stage 2 - (same as above)
s2_emb_kernel = 3,
s2_emb_stride = 2,
s2_proj_kernel = 3,
s2_kv_proj_stride = 2,
s2_heads = 3,
s2_depth = 2,
s2_mlp_mult = 4,
s3_emb_dim = 384, # stage 3 - (same as above)
s3_emb_kernel = 3,
s3_emb_stride = 2,
s3_proj_kernel = 3,
s3_kv_proj_stride = 2,
s3_heads = 4,
s3_depth = 10,
s3_mlp_mult = 4,
dropout = 0.
)
img = torch.randn(1, 3, 224, 224)
pred = v(img) # (1, 1000)
```
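To see how the three stages shrink the feature map, the sketch below applies the standard convolution output-size formula with the kernel and stride settings above, assuming a padding of `kernel // 2` (as in `vit_pytorch/cvt.py`) and a 224 x 224 input. The helper name `conv_out` is hypothetical.

```python
# Illustrative sketch only: `conv_out` is a hypothetical helper for the usual
# output-size formula of a 2d convolution (dilation 1), applied to one side.

def conv_out(size, kernel, stride, padding):
    return (size + 2 * padding - kernel) // stride + 1

# (emb_kernel, emb_stride, proj_kernel, kv_proj_stride) for stages 1 to 3
stages = [(7, 4, 3, 2), (3, 2, 3, 2), (3, 2, 3, 2)]

side = 224
summary = []

for emb_kernel, emb_stride, proj_kernel, kv_stride in stages:
    side = conv_out(side, emb_kernel, emb_stride, emb_kernel // 2)       # embedding conv
    kv_side = conv_out(side, proj_kernel, kv_stride, proj_kernel // 2)   # strided key / value projection
    summary.append((side, side ** 2, kv_side ** 2))                      # (side, queries, keys)

for stage, (side_len, queries, keys) in enumerate(summary, start = 1):
    print(f'stage {stage}: {side_len} x {side_len}, {queries} queries, {keys} keys')
```

With a key / value stride of 2, each stage attends over roughly a quarter as many keys as queries.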
## Twins SVT
<img src="./images/twins_svt.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2104.13840">paper</a> proposes mixing local and global attention, along with a position encoding generator (proposed in <a href="https://arxiv.org/abs/2102.10882">CPVT</a>) and global average pooling, to achieve the same results as <a href="https://arxiv.org/abs/2103.14030">Swin</a>, without the extra complexity of shifted windows, CLS tokens, or positional embeddings.
```python
import torch
from vit_pytorch.twins_svt import TwinsSVT
model = TwinsSVT(
num_classes = 1000, # number of output classes
s1_emb_dim = 64, # stage 1 - patch embedding projected dimension
s1_patch_size = 4, # stage 1 - patch size for patch embedding
s1_local_patch_size = 7, # stage 1 - patch size for local attention
s1_global_k = 7, # stage 1 - global attention key / value reduction factor, defaults to 7 as specified in paper
s1_depth = 1, # stage 1 - number of transformer blocks (local attn -> ff -> global attn -> ff)
s2_emb_dim = 128, # stage 2 (same as above)
s2_patch_size = 2,
s2_local_patch_size = 7,
s2_global_k = 7,
s2_depth = 1,
s3_emb_dim = 256, # stage 3 (same as above)
s3_patch_size = 2,
s3_local_patch_size = 7,
s3_global_k = 7,
s3_depth = 5,
s4_emb_dim = 512, # stage 4 (same as above)
s4_patch_size = 2,
s4_local_patch_size = 7,
s4_global_k = 7,
s4_depth = 4,
peg_kernel_size = 3, # positional encoding generator kernel size
dropout = 0. # dropout
)
img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```
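The stage settings have to fit together, so a quick sanity check can help before building the model. The sketch below is a guess at the shape bookkeeping, not taken from `twins_svt.py`: it assumes each stage's patch embedding divides the side length by the patch size, that local attention tiles the map with windows of `local_patch_size`, and that global attention reduces keys / values by `global_k` per side. The function name `twins_stage_shapes` is hypothetical.

```python
# Illustrative sketch only, under the assumptions stated above.

def twins_stage_shapes(image_size, patch_sizes, local_patch_size = 7, global_k = 7):
    shapes = []
    side = image_size

    for patch_size in patch_sizes:
        assert side % patch_size == 0, 'side must be divisible by the patch size'
        side //= patch_size

        assert side % local_patch_size == 0, 'side must be divisible by the local patch size'
        assert side % global_k == 0, 'side must be divisible by the global reduction factor'

        shapes.append(dict(
            side = side,
            local_windows = (side // local_patch_size) ** 2,
            global_keys = (side // global_k) ** 2
        ))

    return shapes

shapes = twins_stage_shapes(224, patch_sizes = (4, 2, 2, 2))

for stage, shape in enumerate(shapes, start = 1):
    print(stage, shape)
```

Under these assumptions a 224 x 224 image works with the settings above, since every stage's side length stays a multiple of 7.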
## Masked Patch Prediction
@@ -375,6 +584,58 @@ img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)
```
## FAQ
- How do I pass in non-square images?
You can already pass in non-square images - you just have to make sure your height and width are less than or equal to the `image_size`, and both divisible by the `patch_size`.
ex.
```python
import torch
from vit_pytorch import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 128) # <-- not a square
preds = v(img) # (1, 1000)
```
- How do I pass in non-square patches?
```python
import torch
from vit_pytorch import ViT
v = ViT(
num_classes = 1000,
image_size = (256, 128), # image size is a tuple of (height, width)
patch_size = (32, 16), # patch size is a tuple of (height, width)
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 128)
preds = v(img)
```
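Both answers above come down to the same divisibility rule. As a small sketch, the helpers below (hypothetical names `pair` and `patch_grid`, written only for illustration) compute the patch grid for either an int or a `(height, width)` tuple.

```python
# Illustrative sketch only: `pair` and `patch_grid` are hypothetical helpers.

def pair(t):
    """Turn an int into (t, t); leave tuples unchanged."""
    return t if isinstance(t, tuple) else (t, t)

def patch_grid(image_hw, patch_size):
    """Return (patches along height, patches along width)."""
    height, width = pair(image_hw)
    patch_height, patch_width = pair(patch_size)

    if height % patch_height != 0 or width % patch_width != 0:
        raise ValueError('image dimensions must be divisible by the patch size')

    return height // patch_height, width // patch_width

# non-square image with square patches
rows, cols = patch_grid((256, 128), 32)
print(rows, cols, rows * cols)

# non-square image with non-square patches
rows, cols = patch_grid((256, 128), (32, 16))
print(rows, cols, rows * cols)
```

In the first case the image yields fewer patches than a full 256 x 256 input would, which is why it still fits a model built with `image_size = 256`.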
## Resources
Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
@@ -432,6 +693,94 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{touvron2021going,
title = {Going deeper with Image Transformers},
author = {Hugo Touvron and Matthieu Cord and Alexandre Sablayrolles and Gabriel Synnaeve and Hervé Jégou},
year = {2021},
eprint = {2103.17239},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{chen2021crossvit,
title = {CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification},
author = {Chun-Fu Chen and Quanfu Fan and Rameswar Panda},
year = {2021},
eprint = {2103.14899},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{wu2021cvt,
title = {CvT: Introducing Convolutions to Vision Transformers},
author = {Haiping Wu and Bin Xiao and Noel Codella and Mengchen Liu and Xiyang Dai and Lu Yuan and Lei Zhang},
year = {2021},
eprint = {2103.15808},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{heo2021rethinking,
title = {Rethinking Spatial Dimensions of Vision Transformers},
author = {Byeongho Heo and Sangdoo Yun and Dongyoon Han and Sanghyuk Chun and Junsuk Choe and Seong Joon Oh},
year = {2021},
eprint = {2103.16302},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{graham2021levit,
title = {LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
author = {Ben Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Hervé Jégou and Matthijs Douze},
year = {2021},
eprint = {2104.01136},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{li2021localvit,
title = {LocalViT: Bringing Locality to Vision Transformers},
author = {Yawei Li and Kai Zhang and Jiezhang Cao and Radu Timofte and Luc Van Gool},
year = {2021},
eprint = {2104.05707},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{chu2021twins,
title = {Twins: Revisiting Spatial Attention Design in Vision Transformers},
author = {Xiangxiang Chu and Zhi Tian and Yuqing Wang and Bo Zhang and Haibing Ren and Xiaolin Wei and Huaxia Xia and Chunhua Shen},
year = {2021},
eprint = {2104.13840},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},

images/cait.png (binary file not shown, 63 KiB)
images/cross_vit.png (binary file not shown, 82 KiB)
images/cvt.png (binary file not shown, 65 KiB)
images/levit.png (binary file not shown, 71 KiB)
images/pit.png (binary file not shown, 24 KiB)
images/twins_svt.png (binary file not shown, 110 KiB)

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.10.1',
version = '0.17.3',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

177
vit_pytorch/cait.py Normal file

@@ -0,0 +1,177 @@
from random import randrange
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def dropout_layers(layers, dropout):
    if dropout == 0:
        return layers

    num_layers = len(layers)
    to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout

    # make sure at least one layer makes it
    if all(to_drop):
        rand_index = randrange(num_layers)
        to_drop[rand_index] = False

    layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
    return layers

# classes

class LayerScale(nn.Module):
    def __init__(self, dim, fn, depth):
        super().__init__()
        if depth <= 18: # epsilon detailed in section 2 of paper
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.attend = nn.Softmax(dim = -1)

        self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
        self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None):
        b, n, _, h = *x.shape, self.heads

        context = x if not exists(context) else torch.cat((x, context), dim = 1)

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax
        attn = self.attend(dots)
        attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            self.layers.append(nn.ModuleList([
                LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
                LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
            ]))

    def forward(self, x, context = None):
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for attn, ff in layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x

class CaiT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        cls_depth,
        heads,
        mlp_dim,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        layer_dropout = 0.
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.patch_transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
        self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        x += self.pos_embedding[:, :n]
        x = self.dropout(x)

        x = self.patch_transformer(x)

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = self.cls_transformer(cls_tokens, context = x)

        return self.mlp_head(x[:, 0])

270
vit_pytorch/cross_vit.py Normal file

@@ -0,0 +1,270 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# pre-layernorm

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# feedforward

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# attention

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None, kv_include_self = False):
        b, n, _, h = *x.shape, self.heads
        context = default(context, x)

        if kv_include_self:
            context = torch.cat((x, context), dim = 1) # cross attention requires CLS token includes itself as key / value

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer encoder, for small and large patches

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# projecting CLS tokens, in the case that small and large patch tokens have different dimensions

class ProjectInOut(nn.Module):
    def __init__(self, dim_in, dim_out, fn):
        super().__init__()
        self.fn = fn

        need_projection = dim_in != dim_out
        self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity()
        self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity()

    def forward(self, x, *args, **kwargs):
        x = self.project_in(x)
        x = self.fn(x, *args, **kwargs)
        x = self.project_out(x)
        return x

# cross attention transformer

class CrossTransformer(nn.Module):
    def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                ProjectInOut(sm_dim, lg_dim, PreNorm(lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                ProjectInOut(lg_dim, sm_dim, PreNorm(sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout)))
            ]))

    def forward(self, sm_tokens, lg_tokens):
        (sm_cls, sm_patch_tokens), (lg_cls, lg_patch_tokens) = map(lambda t: (t[:, :1], t[:, 1:]), (sm_tokens, lg_tokens))

        for sm_attend_lg, lg_attend_sm in self.layers:
            sm_cls = sm_attend_lg(sm_cls, context = lg_patch_tokens, kv_include_self = True) + sm_cls
            lg_cls = lg_attend_sm(lg_cls, context = sm_patch_tokens, kv_include_self = True) + lg_cls

        sm_tokens = torch.cat((sm_cls, sm_patch_tokens), dim = 1)
        lg_tokens = torch.cat((lg_cls, lg_patch_tokens), dim = 1)
        return sm_tokens, lg_tokens

# multi-scale encoder

class MultiScaleEncoder(nn.Module):
    def __init__(
        self,
        *,
        depth,
        sm_dim,
        lg_dim,
        sm_enc_params,
        lg_enc_params,
        cross_attn_heads,
        cross_attn_depth,
        cross_attn_dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Transformer(dim = sm_dim, dropout = dropout, **sm_enc_params),
                Transformer(dim = lg_dim, dropout = dropout, **lg_enc_params),
                CrossTransformer(sm_dim = sm_dim, lg_dim = lg_dim, depth = cross_attn_depth, heads = cross_attn_heads, dim_head = cross_attn_dim_head, dropout = dropout)
            ]))

    def forward(self, sm_tokens, lg_tokens):
        for sm_enc, lg_enc, cross_attend in self.layers:
            sm_tokens, lg_tokens = sm_enc(sm_tokens), lg_enc(lg_tokens)
            sm_tokens, lg_tokens = cross_attend(sm_tokens, lg_tokens)

        return sm_tokens, lg_tokens

# patch-based image to token embedder

class ImageEmbedder(nn.Module):
    def __init__(
        self,
        *,
        dim,
        image_size,
        patch_size,
        dropout = 0.
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]

        return self.dropout(x)

# cross ViT class

class CrossViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        num_classes,
        sm_dim,
        lg_dim,
        sm_patch_size = 12,
        sm_enc_depth = 1,
        sm_enc_heads = 8,
        sm_enc_mlp_dim = 2048,
        sm_enc_dim_head = 64,
        lg_patch_size = 16,
        lg_enc_depth = 4,
        lg_enc_heads = 8,
        lg_enc_mlp_dim = 2048,
        lg_enc_dim_head = 64,
        cross_attn_depth = 2,
        cross_attn_heads = 8,
        cross_attn_dim_head = 64,
        depth = 3,
        dropout = 0.1,
        emb_dropout = 0.1
    ):
        super().__init__()
        self.sm_image_embedder = ImageEmbedder(dim = sm_dim, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
        self.lg_image_embedder = ImageEmbedder(dim = lg_dim, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)

        self.multi_scale_encoder = MultiScaleEncoder(
            depth = depth,
            sm_dim = sm_dim,
            lg_dim = lg_dim,
            cross_attn_heads = cross_attn_heads,
            cross_attn_dim_head = cross_attn_dim_head,
            cross_attn_depth = cross_attn_depth,
            sm_enc_params = dict(
                depth = sm_enc_depth,
                heads = sm_enc_heads,
                mlp_dim = sm_enc_mlp_dim,
                dim_head = sm_enc_dim_head
            ),
            lg_enc_params = dict(
                depth = lg_enc_depth,
                heads = lg_enc_heads,
                mlp_dim = lg_enc_mlp_dim,
                dim_head = lg_enc_dim_head
            ),
            dropout = dropout
        )

        self.sm_mlp_head = nn.Sequential(nn.LayerNorm(sm_dim), nn.Linear(sm_dim, num_classes))
        self.lg_mlp_head = nn.Sequential(nn.LayerNorm(lg_dim), nn.Linear(lg_dim, num_classes))

    def forward(self, img):
        sm_tokens = self.sm_image_embedder(img)
        lg_tokens = self.lg_image_embedder(img)

        sm_tokens, lg_tokens = self.multi_scale_encoder(sm_tokens, lg_tokens)

        sm_cls, lg_cls = map(lambda t: t[:, 0], (sm_tokens, lg_tokens))

        sm_logits = self.sm_mlp_head(sm_cls)
        lg_logits = self.lg_mlp_head(lg_cls)

        return sm_logits + lg_logits

173
vit_pytorch/cvt.py Normal file

@@ -0,0 +1,173 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper methods

def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

def group_by_key_prefix_and_remove_prefix(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs

# classes

class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (std + self.eps) * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mult, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.BatchNorm2d(dim_in),
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        padding = proj_kernel // 2
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)

        self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        shape = x.shape
        b, n, _, y, h = *shape, self.heads
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))

        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
class CvT(nn.Module):
def __init__(
self,
*,
num_classes,
s1_emb_dim = 64,
s1_emb_kernel = 7,
s1_emb_stride = 4,
s1_proj_kernel = 3,
s1_kv_proj_stride = 2,
s1_heads = 1,
s1_depth = 1,
s1_mlp_mult = 4,
s2_emb_dim = 192,
s2_emb_kernel = 3,
s2_emb_stride = 2,
s2_proj_kernel = 3,
s2_kv_proj_stride = 2,
s2_heads = 3,
s2_depth = 2,
s2_mlp_mult = 4,
s3_emb_dim = 384,
s3_emb_kernel = 3,
s3_emb_stride = 2,
s3_proj_kernel = 3,
s3_kv_proj_stride = 2,
s3_heads = 6,
s3_depth = 10,
s3_mlp_mult = 4,
dropout = 0.
):
super().__init__()
kwargs = dict(locals())
dim = 3
layers = []
for prefix in ('s1', 's2', 's3'):
config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
layers.append(nn.Sequential(
nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
LayerNorm(config['emb_dim']),
Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)
))
dim = config['emb_dim']
self.layers = nn.Sequential(
*layers,
nn.AdaptiveAvgPool2d(1),
Rearrange('... () () -> ...'),
nn.Linear(dim, num_classes)
)
def forward(self, x):
return self.layers(x)
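
The three CvT stages shrink the feature map with strided convolutional embeddings, and every attention layer subsamples keys and values again through `kv_proj_stride`. A minimal sketch of that bookkeeping in plain Python (no PyTorch needed), using the default `s1_`/`s2_`/`s3_` values from the constructor above. The 224-pixel input and the helper names `conv_out` and `cvt_token_counts` are assumptions made for illustration, not part of the commit:

```python
# Output size of a convolution with dilation 1 (the formula used by nn.Conv2d).
def conv_out(size, kernel, stride, padding):
    return (size + 2 * padding - kernel) // stride + 1

# Default stage settings copied from the CvT constructor.
STAGES = [
    dict(emb_kernel=7, emb_stride=4, proj_kernel=3, kv_proj_stride=2),
    dict(emb_kernel=3, emb_stride=2, proj_kernel=3, kv_proj_stride=2),
    dict(emb_kernel=3, emb_stride=2, proj_kernel=3, kv_proj_stride=2),
]

def cvt_token_counts(image_size):
    """Return (feature map side, number of queries, number of keys) per stage."""
    size = image_size
    rows = []
    for cfg in STAGES:
        # strided convolutional embedding, padding = kernel // 2
        size = conv_out(size, cfg['emb_kernel'], cfg['emb_stride'], cfg['emb_kernel'] // 2)
        # keys and values come from a strided depthwise projection
        kv_side = conv_out(size, cfg['proj_kernel'], cfg['kv_proj_stride'], cfg['proj_kernel'] // 2)
        rows.append((size, size * size, kv_side * kv_side))
    return rows

for side, n_queries, n_keys in cvt_token_counts(224):
    print(side, n_queries, n_keys)
```

With these defaults each stage attends from every query position to a quarter as many key positions, which is where the attention cost saving comes from.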


@@ -104,7 +104,8 @@ class DistillWrapper(nn.Module):
teacher,
student,
temperature = 1.,
- alpha = 0.5
+ alpha = 0.5,
+ hard = False
):
super().__init__()
assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
@@ -116,6 +117,7 @@ class DistillWrapper(nn.Module):
num_classes = student.num_classes
self.temperature = temperature
self.alpha = alpha
+ self.hard = hard
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
@@ -137,11 +139,15 @@ class DistillWrapper(nn.Module):
loss = F.cross_entropy(student_logits, labels)
- distill_loss = F.kl_div(
-     F.log_softmax(distill_logits / T, dim = -1),
-     F.softmax(teacher_logits / T, dim = -1).detach(),
- reduction = 'batchmean')
+ if not self.hard:
+     distill_loss = F.kl_div(
+         F.log_softmax(distill_logits / T, dim = -1),
+         F.softmax(teacher_logits / T, dim = -1).detach(),
+     reduction = 'batchmean')
- distill_loss *= T ** 2
+     distill_loss *= T ** 2
+ else:
+     teacher_labels = teacher_logits.argmax(dim = -1)
+     distill_loss = F.cross_entropy(distill_logits, teacher_labels)
- return loss * alpha + distill_loss * (1 - alpha)
+ return loss * (1 - alpha) + distill_loss * alpha
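
The hunk above adds a `hard` switch: soft distillation matches the teacher's temperature-softened distribution with a KL term scaled by `T ** 2`, while hard distillation is a cross-entropy against the teacher's argmax labels, and `alpha` now weights the distillation term. A minimal sketch of the two branches in plain Python on nested lists, so it runs without PyTorch. The function names are hypothetical and the averaging mirrors `reduction = 'batchmean'` only for this 2D case:

```python
import math

def log_softmax(logits):
    # numerically stable log-softmax over one row
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]

def soft_distill_loss(logits, teacher_logits, T=1.0):
    """KL(teacher || student) at temperature T, averaged over the batch, times T^2."""
    total = 0.0
    for s_row, t_row in zip(logits, teacher_logits):
        log_q = log_softmax([v / T for v in s_row])
        log_p = log_softmax([v / T for v in t_row])
        total += sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
    return total / len(logits) * T ** 2

def hard_distill_loss(logits, teacher_logits):
    """Cross-entropy against the teacher's argmax label, averaged over the batch."""
    total = 0.0
    for s_row, t_row in zip(logits, teacher_logits):
        label = max(range(len(t_row)), key=t_row.__getitem__)
        total -= log_softmax(s_row)[label]
    return total / len(logits)

def combine(loss, distill_loss, alpha=0.5):
    # after this change, alpha is the weight of the distillation term
    return loss * (1 - alpha) + distill_loss * alpha

print(hard_distill_loss([[0.0, 0.0]], [[1.0, 3.0]]))  # cross-entropy of a uniform prediction
```

The soft loss is zero when the two sets of logits agree, whereas the hard loss only cares which class the teacher ranks first.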

193
vit_pytorch/levit.py Normal file

@@ -0,0 +1,193 @@
from math import ceil
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, l = 3):
val = val if isinstance(val, tuple) else (val,)
return (*val, *((val[-1],) * max(l - len(val), 0)))
def always(val):
return lambda *args, **kwargs: val
# classes
class FeedForward(nn.Module):
def __init__(self, dim, mult, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(dim * mult, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False):
super().__init__()
inner_dim_key = dim_key * heads
inner_dim_value = dim_value * heads
dim_out = default(dim_out, dim)
self.heads = heads
self.scale = dim_key ** -0.5
self.to_q = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, stride = (2 if downsample else 1), bias = False), nn.BatchNorm2d(inner_dim_key))
self.to_k = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, bias = False), nn.BatchNorm2d(inner_dim_key))
self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))
self.attend = nn.Softmax(dim = -1)
out_batch_norm = nn.BatchNorm2d(dim_out)
nn.init.zeros_(out_batch_norm.weight)
self.to_out = nn.Sequential(
nn.GELU(),
nn.Conv2d(inner_dim_value, dim_out, 1),
out_batch_norm,
nn.Dropout(dropout)
)
# positional bias
self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)
q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
k_range = torch.arange(fmap_size)
q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)
q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
x_rel, y_rel = rel_pos.unbind(dim = -1)
pos_indices = (x_rel * fmap_size) + y_rel
self.register_buffer('pos_indices', pos_indices)
def apply_pos_bias(self, fmap):
bias = self.pos_bias(self.pos_indices)
bias = rearrange(bias, 'i j h -> () h i j')
return fmap + bias
def forward(self, x):
b, n, *_, h = *x.shape, self.heads
q = self.to_q(x)
y = q.shape[2]
qkv = (q, self.to_k(x), self.to_v(x))
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
dots = self.apply_pos_bias(dots)
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False):
super().__init__()
dim_out = default(dim_out, dim)
self.layers = nn.ModuleList([])
self.attn_residual = (not downsample) and dim == dim_out
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out),
FeedForward(dim_out, mlp_mult, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
attn_res = (x if self.attn_residual else 0)
x = attn(x) + attn_res
x = ff(x) + x
return x
class LeViT(nn.Module):
def __init__(
self,
*,
image_size,
num_classes,
dim,
depth,
heads,
mlp_mult,
stages = 3,
dim_key = 32,
dim_value = 64,
dropout = 0.,
num_distill_classes = None
):
super().__init__()
dims = cast_tuple(dim, stages)
depths = cast_tuple(depth, stages)
layer_heads = cast_tuple(heads, stages)
assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must each be a tuple no longer than the designated number of stages'
self.conv_embedding = nn.Sequential(
nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1)
)
fmap_size = image_size // (2 ** 4)
layers = []
for ind, dim, depth, heads in zip(range(stages), dims, depths, layer_heads):
is_last = ind == (stages - 1)
layers.append(Transformer(dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult, dropout))
if not is_last:
next_dim = dims[ind + 1]
layers.append(Transformer(dim, fmap_size, 1, heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
fmap_size = ceil(fmap_size / 2)
self.backbone = nn.Sequential(*layers)
self.pool = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
Rearrange('... () () -> ...')
)
self.distill_head = nn.Linear(dim, num_distill_classes) if exists(num_distill_classes) else always(None)
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, img):
x = self.conv_embedding(img)
x = self.backbone(x)
x = self.pool(x)
out = self.mlp_head(x)
distill = self.distill_head(x)
if exists(distill):
return out, distill
return out
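
Three pieces of bookkeeping in `levit.py` are easy to check by hand: how `cast_tuple` pads per-stage settings, how `fmap_size` shrinks between stages, and how `pos_indices` maps a query/key pair to a row of the `pos_bias` embedding. A minimal pure-Python sketch follows. `cast_tuple` is copied from the file, while `fmap_sizes` and `pos_indices` are hypothetical helpers that restate the constructor logic with lists instead of tensors, and the 224-pixel input is an assumption:

```python
from math import ceil

def cast_tuple(val, l=3):
    # copied from levit.py: repeat the last entry until the tuple has length l
    val = val if isinstance(val, tuple) else (val,)
    return (*val, *((val[-1],) * max(l - len(val), 0)))

def fmap_sizes(image_size, stages=3):
    # four stride-2 convolutions, then one halving between consecutive stages
    size = image_size // (2 ** 4)
    sizes = []
    for ind in range(stages):
        sizes.append(size)
        if ind != stages - 1:
            size = ceil(size / 2)
    return sizes

def pos_indices(fmap_size, downsample=False):
    # index = |dx| * fmap_size + |dy| for every (query, key) pair
    q_range = range(0, fmap_size, 2 if downsample else 1)
    k_range = range(fmap_size)
    q_pos = [(i, j) for i in q_range for j in q_range]
    k_pos = [(i, j) for i in k_range for j in k_range]
    return [[abs(qi - ki) * fmap_size + abs(qj - kj) for ki, kj in k_pos]
            for qi, qj in q_pos]

print(cast_tuple((256, 384)))
print(fmap_sizes(224))
print(pos_indices(2))
```

Because only absolute offsets are used, every index stays below `fmap_size * fmap_size`, which matches the size of the embedding table.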

152
vit_pytorch/local_vit.py Normal file

@@ -0,0 +1,152 @@
from math import sqrt
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# classes
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class ExcludeCLS(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
cls_token, x = x[:, :1], x[:, 1:]
x = self.fn(x, **kwargs)
return torch.cat((cls_token, x), dim = 1)
# prenorm
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
# feed forward related classes
class DepthWiseConv2d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
)
def forward(self, x):
return self.net(x)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, hidden_dim, 1),
nn.Hardswish(),
DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
nn.Hardswish(),
nn.Dropout(dropout),
nn.Conv2d(hidden_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
h = w = int(sqrt(x.shape[-2]))
x = rearrange(x, 'b (h w) c -> b c h w', h = h, w = w)
x = self.net(x)
x = rearrange(x, 'b c h w -> b (h w) c')
return x
# attention
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
ExcludeCLS(Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x)
x = ff(x)
return x
# main class
class LocalViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
return self.mlp_head(x[:, 0])
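
In `local_vit.py` the feed-forward runs as a convolution over the patch grid, so `ExcludeCLS` first splits off the class token and `FeedForward.forward` reshapes the remaining tokens into a square grid and back. A minimal sketch of that data flow using plain lists in place of tensors. All function names here are hypothetical stand-ins, and multiplying by 10 merely marks which tokens went through the grid transform:

```python
from math import sqrt

def to_grid(patch_tokens):
    # mirrors 'b (h w) c -> b c h w' with h = w = int(sqrt(n))
    side = int(sqrt(len(patch_tokens)))
    assert side * side == len(patch_tokens), 'patch tokens must form a square grid'
    return [patch_tokens[r * side:(r + 1) * side] for r in range(side)]

def from_grid(grid):
    # mirrors 'b c h w -> b (h w) c'
    return [tok for row in grid for tok in row]

def grid_feedforward(patch_tokens):
    # stand-in for the conv feed-forward: reshape, transform on the grid, flatten
    grid = to_grid(patch_tokens)
    grid = [[tok * 10 for tok in row] for row in grid]
    return from_grid(grid)

def exclude_cls(fn, tokens):
    # the class token bypasses fn and is concatenated back in front
    cls_token, patch_tokens = tokens[:1], tokens[1:]
    return cls_token + fn(patch_tokens)

print(exclude_cls(grid_feedforward, ['cls', 1, 2, 3, 4]))
```

The square-root reshape is also why the number of patches has to be a perfect square for this model.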

183
vit_pytorch/pit.py Normal file

@@ -0,0 +1,183 @@
from math import sqrt
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def cast_tuple(val, num):
return val if isinstance(val, tuple) else (val,) * num
def conv_output_size(image_size, kernel_size, stride, padding = 0):
return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
# depthwise convolution, for pooling
class DepthWiseConv2d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
nn.Conv2d(dim_out, dim_out, kernel_size = 1, bias = bias)
)
def forward(self, x):
return self.net(x)
# pooling layer
class Pool(nn.Module):
def __init__(self, dim):
super().__init__()
self.downsample = DepthWiseConv2d(dim, dim * 2, kernel_size = 3, stride = 2, padding = 1)
self.cls_ff = nn.Linear(dim, dim * 2)
def forward(self, x):
cls_token, tokens = x[:, :1], x[:, 1:]
cls_token = self.cls_ff(cls_token)
tokens = rearrange(tokens, 'b (h w) c -> b c h w', h = int(sqrt(tokens.shape[1])))
tokens = self.downsample(tokens)
tokens = rearrange(tokens, 'b c h w -> b (h w) c')
return torch.cat((cls_token, tokens), dim = 1)
# main class
class PiT(nn.Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.
):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing'
heads = cast_tuple(heads, len(depth))
patch_dim = 3 * patch_size ** 2
self.to_patch_embedding = nn.Sequential(
nn.Unfold(kernel_size = patch_size, stride = patch_size // 2),
Rearrange('b c n -> b n c'),
nn.Linear(patch_dim, dim)
)
output_size = conv_output_size(image_size, patch_size, patch_size // 2)
num_patches = output_size ** 2
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
layers = []
for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
not_last = ind < (len(depth) - 1)
layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))
if not_last:
layers.append(Pool(dim))
dim *= 2
self.layers = nn.Sequential(*layers)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x = self.dropout(x)
x = self.layers(x)
return self.mlp_head(x[:, 0])
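
PiT extracts overlapping patches (stride `patch_size // 2`), then each `Pool` layer halves the token grid with a stride-2 depthwise convolution and doubles the channel dimension. A minimal sketch of the resulting sequence shapes in plain Python. `conv_output_size` is copied from the file, `pit_stage_shapes` is a hypothetical helper, and the concrete sizes (224-pixel image, patch size 14, `dim` 256, three stages) are assumptions chosen for illustration:

```python
def conv_output_size(image_size, kernel_size, stride, padding=0):
    # copied from pit.py
    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)

def pit_stage_shapes(image_size, patch_size, dim, num_stages):
    """Return (sequence length including CLS, channel dim) for every stage."""
    side = conv_output_size(image_size, patch_size, patch_size // 2)
    shapes = []
    for ind in range(num_stages):
        shapes.append((side * side + 1, dim))
        if ind < num_stages - 1:
            # Pool: depthwise conv with kernel 3, stride 2, padding 1, and dim * 2
            side = conv_output_size(side, 3, 2, padding=1)
            dim *= 2
    return shapes

print(pit_stage_shapes(224, 14, 256, 3))
```

The `+ 1` accounts for the class token, which `Pool` projects separately through `cls_ff` so that it keeps up with the doubled dimension.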


@@ -4,9 +4,6 @@ from torch import nn
from vit_pytorch.vit import Attention
- def exists(val):
- return val is not None
def find_modules(nn_module, type):
return [module for module in nn_module.modules() if isinstance(module, type)]
@@ -25,7 +22,7 @@ class Recorder(nn.Module):
self.recordings.append(output.clone().detach())
def _register_hook(self):
- modules = find_modules(self, Attention)
+ modules = find_modules(self.vit.transformer, Attention)
for module in modules:
handle = module.attend.register_forward_hook(self._hook)
self.hooks.append(handle)

209
vit_pytorch/rvt.py Normal file

@@ -0,0 +1,209 @@
from math import sqrt, pi, log
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# rotary embeddings
def rotate_every_two(x):
x = rearrange(x, '... (d j) -> ... d j', j = 2)
x1, x2 = x.unbind(dim = -1)
x = torch.stack((-x2, x1), dim = -1)
return rearrange(x, '... d j -> ... (d j)')
class AxialRotaryEmbedding(nn.Module):
def __init__(self, dim, max_freq = 10):
super().__init__()
self.dim = dim
scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
self.register_buffer('scales', scales)
def forward(self, x):
device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))
seq = torch.linspace(-1., 1., steps = n, device = device)
seq = seq.unsqueeze(-1)
scales = self.scales[(*((None,) * (len(seq.shape) - 1)), Ellipsis)]
scales = scales.to(x)
seq = seq * scales * pi
x_sinu = repeat(seq, 'i d -> i j d', j = n)
y_sinu = repeat(seq, 'j d -> i j d', i = n)
sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)
cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)
sin, cos = map(lambda t: rearrange(t, 'i j d -> (i j) d'), (sin, cos))
sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos))
return sin, cos
class DepthWiseConv2d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
)
def forward(self, x):
return self.net(x)
# helper classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class SpatialConv(nn.Module):
def __init__(self, dim_in, dim_out, kernel, bias = False):
super().__init__()
self.conv = DepthWiseConv2d(dim_in, dim_out, kernel, padding = kernel // 2, bias = False)
self.cls_proj = nn.Linear(dim_in, dim_out) if dim_in != dim_out else nn.Identity()
def forward(self, x, fmap_dims):
cls_token, x = x[:, :1], x[:, 1:]
x = rearrange(x, 'b (h w) d -> b d h w', **fmap_dims)
x = self.conv(x)
x = rearrange(x, 'b d h w -> b (h w) d')
cls_token = self.cls_proj(cls_token)
return torch.cat((cls_token, x), dim = 1)
class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim = -1)
return F.gelu(gates) * x
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0., use_glu = True):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim * 2 if use_glu else hidden_dim),
GEGLU() if use_glu else nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = True, use_ds_conv = True, conv_query_kernel = 5):
super().__init__()
inner_dim = dim_head * heads
self.use_rotary = use_rotary
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.use_ds_conv = use_ds_conv
self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False) if use_ds_conv else nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, pos_emb, fmap_dims):
b, n, _, h = *x.shape, self.heads
to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {}
q = self.to_q(x, **to_q_kwargs)
qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
if self.use_rotary:
# apply 2d rotary embeddings to queries and keys, excluding CLS tokens
sin, cos = pos_emb
dim_rotary = sin.shape[-1]
(q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))
# handle the case where rotary dimension < head dimension
(q, q_pass), (k, k_pass) = map(lambda t: (t[..., :dim_rotary], t[..., dim_rotary:]), (q, k))
q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
# concat back the CLS tokens
q = torch.cat((q_cls, q), dim = 1)
k = torch.cat((k_cls, k), dim = 1)
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
super().__init__()
self.layers = nn.ModuleList([])
self.pos_emb = AxialRotaryEmbedding(dim_head)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu))
]))
def forward(self, x, fmap_dims):
pos_emb = self.pos_emb(x[:, 1:])
for attn, ff in self.layers:
x = attn(x, pos_emb = pos_emb, fmap_dims = fmap_dims) + x
x = ff(x) + x
return x
# Rotary Vision Transformer
class RvT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
self.patch_size = patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear(patch_dim, dim),
)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary, use_ds_conv, use_glu)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
b, _, h, w, p = *img.shape, self.patch_size
x = self.to_patch_embedding(img)
n = x.shape[1]
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
fmap_dims = {'h': h // p, 'w': w // p}
x = self.transformer(x, fmap_dims = fmap_dims)
return self.mlp_head(x[:, 0])
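
The rotary step in `Attention.forward` computes `t * cos + rotate_every_two(t) * sin`, which rotates each consecutive pair of features by a position-dependent angle. A minimal sketch on plain Python lists, to show that this really is a rotation: applying the same angles to a query and a key leaves their dot product unchanged. `apply_rotary` and `rotary_scales` are hypothetical helpers, with the latter restating the `torch.logspace` call from `AxialRotaryEmbedding` under the assumption that `logspace` means `base ** linspace(start, end, steps)`:

```python
import math

def rotate_every_two(x):
    # (x1, x2, x3, x4, ...) -> (-x2, x1, -x4, x3, ...)
    out = []
    for x1, x2 in zip(x[0::2], x[1::2]):
        out += [-x2, x1]
    return out

def apply_rotary(x, angles):
    # one angle per feature pair, repeated twice like repeat(..., '(d j)', j = 2)
    sin = [math.sin(a) for a in angles for _ in range(2)]
    cos = [math.cos(a) for a in angles for _ in range(2)]
    rot = rotate_every_two(x)
    return [v * c + r * s for v, c, r, s in zip(x, cos, rot, sin)]

def rotary_scales(dim, max_freq=10):
    # dim // 4 frequencies per axis, log-spaced from 1 to max_freq / 2
    steps = dim // 4
    if steps == 1:
        return [1.0]
    end = math.log(max_freq / 2) / math.log(2)
    return [2 ** (end * i / (steps - 1)) for i in range(steps)]

def dot(a, b):
    return sum(u * v for u, v in zip(a, b))

q, k, angles = [1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.25], [0.3, 1.1]
print(dot(q, k), dot(apply_rotary(q, angles), apply_rotary(k, angles)))
```

With `dim_head = 64` there are 16 scales per axis, so the concatenated x and y embeddings, each repeated twice, cover all 64 head dimensions.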

229
vit_pytorch/twins_svt.py Normal file

@@ -0,0 +1,229 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helper methods
def group_dict_by_key(cond, d):
return_val = [dict(), dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def group_by_key_prefix_and_remove_prefix(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# classes
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(dim * mult, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class PatchEmbedding(nn.Module):
def __init__(self, *, dim, dim_out, patch_size):
super().__init__()
self.dim = dim
self.dim_out = dim_out
self.patch_size = patch_size
self.proj = nn.Conv2d(patch_size ** 2 * dim, dim_out, 1)
def forward(self, fmap):
p = self.patch_size
fmap = rearrange(fmap, 'b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = p, p2 = p)
return self.proj(fmap)
class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = Residual(nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1))
def forward(self, x):
return self.proj(x)
class LocalAttention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., patch_size = 7):
super().__init__()
inner_dim = dim_head * heads
self.patch_size = patch_size
self.heads = heads
self.scale = dim_head ** -0.5
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, fmap):
shape, p = fmap.shape, self.patch_size
b, n, x, y, h = *shape, self.heads
x, y = map(lambda t: t // p, (x, y))
fmap = rearrange(fmap, 'b c (x p1) (y p2) -> (b x y) c p1 p2', p1 = p, p2 = p)
q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
q, k, v = map(lambda t: rearrange(t, 'b (h d) p1 p2 -> (b h) (p1 p2) d', h = h), (q, k, v))
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = dots.softmax(dim = - 1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b x y h) (p1 p2) d -> b (h d) (x p1) (y p2)', h = h, x = x, y = y, p1 = p, p2 = p)
return self.to_out(out)
class GlobalAttention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., k = 7):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
shape = x.shape
b, n, _, y, h = *shape, self.heads
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = dots.softmax(dim = -1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads = 8, dim_head = 64, mlp_mult = 4, local_patch_size = 7, global_k = 7, dropout = 0., has_local = True):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size))) if has_local else nn.Identity(),
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))) if has_local else nn.Identity(),
Residual(PreNorm(dim, GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k))),
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout)))
]))
def forward(self, x):
for local_attn, ff1, global_attn, ff2 in self.layers:
x = local_attn(x)
x = ff1(x)
x = global_attn(x)
x = ff2(x)
return x
class TwinsSVT(nn.Module):
def __init__(
self,
*,
num_classes,
s1_emb_dim = 64,
s1_patch_size = 4,
s1_local_patch_size = 7,
s1_global_k = 7,
s1_depth = 1,
s2_emb_dim = 128,
s2_patch_size = 2,
s2_local_patch_size = 7,
s2_global_k = 7,
s2_depth = 1,
s3_emb_dim = 256,
s3_patch_size = 2,
s3_local_patch_size = 7,
s3_global_k = 7,
s3_depth = 5,
s4_emb_dim = 512,
s4_patch_size = 2,
s4_local_patch_size = 7,
s4_global_k = 7,
s4_depth = 4,
peg_kernel_size = 3,
dropout = 0.
):
super().__init__()
kwargs = dict(locals())
dim = 3
layers = []
for prefix in ('s1', 's2', 's3', 's4'):
config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
is_last = prefix == 's4'
dim_next = config['emb_dim']
layers.append(nn.Sequential(
PatchEmbedding(dim = dim, dim_out = dim_next, patch_size = config['patch_size']),
Transformer(dim = dim_next, depth = 1, local_patch_size = config['local_patch_size'], global_k = config['global_k'], dropout = dropout, has_local = not is_last),
PEG(dim = dim_next, kernel_size = peg_kernel_size),
Transformer(dim = dim_next, depth = config['depth'], local_patch_size = config['local_patch_size'], global_k = config['global_k'], dropout = dropout, has_local = not is_last)
))
dim = dim_next
self.layers = nn.Sequential(
*layers,
nn.AdaptiveAvgPool2d(1),
Rearrange('... () () -> ...'),
nn.Linear(dim, num_classes)
)
def forward(self, x):
return self.layers(x)
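
`TwinsSVT` (like `CvT`) takes flat `s1_...` to `s4_...` keyword arguments and peels off one stage's settings at a time with `group_by_key_prefix_and_remove_prefix`. A minimal sketch of that mechanism together with the feature map sizes it leads to. The two helpers are copied from the file, while `twins_fmap_sizes`, the shortened `kwargs` dictionary and the 224-pixel input are assumptions for illustration:

```python
def group_dict_by_key(cond, d):
    # copied from twins_svt.py: split d into (matching, non-matching) keys
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

def group_by_key_prefix_and_remove_prefix(prefix, d):
    # copied from twins_svt.py: strip the prefix from the matching keys
    kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs

def twins_fmap_sizes(image_size, patch_sizes=(4, 2, 2, 2), local_patch_size=7, global_k=7):
    """Return (side, local windows, global key/value positions) per stage."""
    rows, size = [], image_size
    for ind, p in enumerate(patch_sizes):
        assert size % p == 0, 'feature map must be divisible by the patch size'
        size //= p
        is_last = ind == len(patch_sizes) - 1
        # the last stage is built with has_local = False
        windows = None if is_last else (size // local_patch_size) ** 2
        # keys and values come from a k x k convolution with stride k
        kv = ((size - global_k) // global_k + 1) ** 2
        rows.append((size, windows, kv))
    return rows

kwargs = dict(num_classes=1000, s1_emb_dim=64, s1_patch_size=4, s2_emb_dim=128, dropout=0.)
config, kwargs = group_by_key_prefix_and_remove_prefix('s1_', kwargs)
print(config, kwargs)
print(twins_fmap_sizes(224))
```

Note that the local attention rearrange only works when the feature map side is a multiple of `local_patch_size`, which holds for 56, 28 and 14 with the default window of 7.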


@@ -5,6 +5,13 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
+ # helpers
+ def pair(t):
+ return t if isinstance(t, tuple) else (t, t)
+ # classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
@@ -74,13 +81,17 @@ class Transformer(nn.Module):
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
- assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
- num_patches = (image_size // patch_size) ** 2
- patch_dim = channels * patch_size ** 2
+ image_height, image_width = pair(image_size)
+ patch_height, patch_width = pair(patch_size)
+ assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+ num_patches = (image_height // patch_height) * (image_width // patch_width)
+ patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
- Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+ Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
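
The `pair` helper lets `image_size` and `patch_size` be either a single integer or a `(height, width)` tuple, which is what makes non-square patches possible. A minimal sketch of the resulting patch arithmetic in plain Python. `pair` is copied from the hunk above, whereas `patch_stats` and the example sizes are assumptions for illustration:

```python
def pair(t):
    # copied from the hunk above
    return t if isinstance(t, tuple) else (t, t)

def patch_stats(image_size, patch_size, channels=3):
    """Return (num_patches, patch_dim) the way the updated ViT constructor computes them."""
    image_height, image_width = pair(image_size)
    patch_height, patch_width = pair(patch_size)
    assert image_height % patch_height == 0 and image_width % patch_width == 0, \
        'Image dimensions must be divisible by the patch size.'
    num_patches = (image_height // patch_height) * (image_width // patch_width)
    patch_dim = channels * patch_height * patch_width
    return num_patches, patch_dim

print(patch_stats(256, 32))               # square image, square patches
print(patch_stats((256, 128), (32, 16)))  # rectangular image, rectangular patches
```

Passing plain integers reproduces the old square behaviour, so existing callers should not need to change.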