Compare commits
39 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
30b37c4028 | ||
|
|
4497f1e90f | ||
|
|
b50d3e1334 | ||
|
|
e075460937 | ||
|
|
5e23e48e4d | ||
|
|
db04c0f319 | ||
|
|
0f31ca79e3 | ||
|
|
2cb6b35030 | ||
|
|
2ec9161a98 | ||
|
|
3a3038c702 | ||
|
|
b1f1044c8e | ||
|
|
deb96201d5 | ||
|
|
05b47cc070 | ||
|
|
9ef8da4759 | ||
|
|
506fcf83a6 | ||
|
|
6fb360a1ff | ||
|
|
9332b9e8c9 | ||
|
|
da950e6d2c | ||
|
|
4b9a02d89c | ||
|
|
518924eac5 | ||
|
|
e712003dfb | ||
|
|
d04ce06a30 | ||
|
|
8135d70e4e | ||
|
|
3067155cea | ||
|
|
ab7315cca1 | ||
|
|
15294c304e | ||
|
|
b900850144 | ||
|
|
78489045cd | ||
|
|
173e07e02e | ||
|
|
0e63766e54 | ||
|
|
a6cbda37b9 | ||
|
|
73de1e8a73 | ||
|
|
1698b7bef8 | ||
|
|
fc14561de7 | ||
|
|
be5d560821 | ||
|
|
77703ae1fc | ||
|
|
a0a4fa5e7d | ||
|
|
174e71cf53 | ||
|
|
e14bd14a8f |
383
README.md
@@ -1,4 +1,4 @@
|
||||
<img src="./vit.gif" width="500px"></img>
|
||||
<img src="./images/vit.gif" width="500px"></img>
|
||||
|
||||
## Vision Transformer - Pytorch
|
||||
|
||||
@@ -33,9 +33,8 @@ v = ViT(
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to attend to
|
||||
|
||||
preds = v(img, mask = mask) # (1, 1000)
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Parameters
|
||||
@@ -64,7 +63,7 @@ Embedding dropout rate.
|
||||
|
||||
## Distillation
|
||||
|
||||
<img src="./distill.png" width="300px"></img>
|
||||
<img src="./images/distill.png" width="300px"></img>
|
||||
|
||||
A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that use of a distillation token for distilling knowledge from convolutional nets to vision transformer can yield small and efficient vision transformers. This repository offers the means to do distillation easily.
|
||||
|
||||
@@ -94,7 +93,8 @@ distiller = DistillWrapper(
|
||||
student = v,
|
||||
teacher = teacher,
|
||||
temperature = 3, # temperature of distillation
|
||||
alpha = 0.5 # trade between main loss and distillation loss
|
||||
alpha = 0.5, # trade between main loss and distillation loss
|
||||
hard = False # whether to use soft or hard distillation
|
||||
)
|
||||
|
||||
img = torch.randn(2, 3, 256, 256)
|
||||
@@ -117,9 +117,67 @@ v = v.to_vit()
|
||||
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
|
||||
```
|
||||
|
||||
## Deep ViT
|
||||
|
||||
This <a href="https://arxiv.org/abs/2103.11886">paper</a> notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a> paper from NLP.
|
||||
|
||||
You can use it as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.deepvit import DeepViT
|
||||
|
||||
v = DeepViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## CaiT
|
||||
|
||||
<a href="https://arxiv.org/abs/2103.17239">This paper</a> also notes difficulty in training vision transformers at greater depths and proposes two solutions. First it proposes to do per-channel multiplication of the output of the residual block. Second, it proposes to have the patches attend to one another, and only allow the CLS token to attend to the patches in the last few layers.
|
||||
|
||||
They also add <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a>, noting improvements
|
||||
|
||||
You can use this scheme as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.cait import CaiT
|
||||
|
||||
v = CaiT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 12, # depth of transformer for patch to patch attention only
|
||||
cls_depth = 2, # depth of cross attention of CLS tokens to patch
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1,
|
||||
layer_dropout = 0.05 # randomly dropout 5% of the layers
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Token-to-Token ViT
|
||||
|
||||
<img src="./t2t.png" width="400px"></img>
|
||||
<img src="./images/t2t.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2101.11986">This paper</a> proposes that the first couple layers should downsample the image sequence by unfolding, leading to overlapping image data in each token as shown in the figure above. You can use this variant of the `ViT` as follows.
|
||||
|
||||
@@ -138,7 +196,229 @@ v = T2TViT(
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
v(img) # (1, 1000)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Cross ViT
|
||||
|
||||
<img src="./images/cross_vit.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2103.14899">This paper</a> proposes to have two vision transformers processing the image at different scales, cross attending to one every so often. They show improvements on top of the base vision transformer.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.cross_vit import CrossViT
|
||||
|
||||
v = CrossViT(
|
||||
image_size = 256,
|
||||
num_classes = 1000,
|
||||
depth = 4, # number of multi-scale encoding blocks
|
||||
sm_dim = 192, # high res dimension
|
||||
sm_patch_size = 16, # high res patch size (should be smaller than lg_patch_size)
|
||||
sm_enc_depth = 2, # high res depth
|
||||
sm_enc_heads = 8, # high res heads
|
||||
sm_enc_mlp_dim = 2048, # high res feedforward dimension
|
||||
lg_dim = 384, # low res dimension
|
||||
lg_patch_size = 64, # low res patch size
|
||||
lg_enc_depth = 3, # low res depth
|
||||
lg_enc_heads = 8, # low res heads
|
||||
lg_enc_mlp_dim = 2048, # low res feedforward dimensions
|
||||
cross_attn_depth = 2, # cross attention rounds
|
||||
cross_attn_heads = 8, # cross attention heads
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
pred = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## PiT
|
||||
|
||||
<img src="./images/pit.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2103.16302">This paper</a> proposes to downsample the tokens through a pooling procedure using depth-wise convolutions.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.pit import PiT
|
||||
|
||||
v = PiT(
|
||||
image_size = 224,
|
||||
patch_size = 14,
|
||||
dim = 256,
|
||||
num_classes = 1000,
|
||||
depth = (3, 3, 3), # list of depths, indicating the number of rounds of each stage before a downsample
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
# forward pass now returns predictions and the attention maps
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## LeViT
|
||||
|
||||
<img src="./images/levit.png" width="300px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2104.01136">This paper</a> proposes a number of changes, including (1) convolutional embedding instead of patch-wise projection (2) downsampling in stages (3) extra non-linearity in attention (4) 2d relative positional biases instead of initial absolute positional bias (5) batchnorm in place of layernorm.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.levit import LeViT
|
||||
|
||||
levit = LeViT(
|
||||
image_size = 224,
|
||||
num_classes = 1000,
|
||||
stages = 3, # number of stages
|
||||
dim = (256, 384, 512), # dimensions at each stage
|
||||
depth = 4, # transformer of depth 4 at each stage
|
||||
heads = (4, 6, 8), # heads at each stage
|
||||
mlp_mult = 2,
|
||||
dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
|
||||
levit(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## CvT
|
||||
|
||||
<img src="./images/cvt.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2103.15808">This paper</a> proposes mixing convolutions and attention. Specifically, convolutions are used to embed and downsample the image / feature map in three stages. Depthwise-convoltion is also used to project the queries, keys, and values for attention.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.cvt import CvT
|
||||
|
||||
v = CvT(
|
||||
num_classes = 1000,
|
||||
s1_emb_dim = 64, # stage 1 - dimension
|
||||
s1_emb_kernel = 7, # stage 1 - conv kernel
|
||||
s1_emb_stride = 4, # stage 1 - conv stride
|
||||
s1_proj_kernel = 3, # stage 1 - attention ds-conv kernel size
|
||||
s1_kv_proj_stride = 2, # stage 1 - attention key / value projection stride
|
||||
s1_heads = 1, # stage 1 - heads
|
||||
s1_depth = 1, # stage 1 - depth
|
||||
s1_mlp_mult = 4, # stage 1 - feedforward expansion factor
|
||||
s2_emb_dim = 192, # stage 2 - (same as above)
|
||||
s2_emb_kernel = 3,
|
||||
s2_emb_stride = 2,
|
||||
s2_proj_kernel = 3,
|
||||
s2_kv_proj_stride = 2,
|
||||
s2_heads = 3,
|
||||
s2_depth = 2,
|
||||
s2_mlp_mult = 4,
|
||||
s3_emb_dim = 384, # stage 3 - (same as above)
|
||||
s3_emb_kernel = 3,
|
||||
s3_emb_stride = 2,
|
||||
s3_proj_kernel = 3,
|
||||
s3_kv_proj_stride = 2,
|
||||
s3_heads = 4,
|
||||
s3_depth = 10,
|
||||
s3_mlp_mult = 4,
|
||||
dropout = 0.
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
|
||||
pred = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Masked Patch Prediction
|
||||
|
||||
Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import ViT
|
||||
from vit_pytorch.mpp import MPP
|
||||
|
||||
model = ViT(
|
||||
image_size=256,
|
||||
patch_size=32,
|
||||
num_classes=1000,
|
||||
dim=1024,
|
||||
depth=6,
|
||||
heads=8,
|
||||
mlp_dim=2048,
|
||||
dropout=0.1,
|
||||
emb_dropout=0.1
|
||||
)
|
||||
|
||||
mpp_trainer = MPP(
|
||||
transformer=model,
|
||||
patch_size=32,
|
||||
dim=1024,
|
||||
mask_prob=0.15, # probability of using token in masked prediction task
|
||||
random_patch_prob=0.30, # probability of randomly replacing a token being used for mpp
|
||||
replace_prob=0.50, # probability of replacing a token being used for mpp with the mask token
|
||||
)
|
||||
|
||||
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
|
||||
|
||||
def sample_unlabelled_images():
|
||||
return torch.randn(20, 3, 256, 256)
|
||||
|
||||
for _ in range(100):
|
||||
images = sample_unlabelled_images()
|
||||
loss = mpp_trainer(images)
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
opt.step()
|
||||
|
||||
# save your improved network
|
||||
torch.save(model.state_dict(), './pretrained-net.pt')
|
||||
```
|
||||
|
||||
## Accessing Attention
|
||||
|
||||
If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
# import Recorder and wrap the ViT
|
||||
|
||||
from vit_pytorch.recorder import Recorder
|
||||
v = Recorder(v)
|
||||
|
||||
# forward pass now returns predictions and the attention maps
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
preds, attns = v(img)
|
||||
|
||||
# there is one extra patch due to the CLS token
|
||||
|
||||
attns # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)
|
||||
```
|
||||
|
||||
to cleanup the class and the hooks once you have collected enough data
|
||||
|
||||
```python
|
||||
v = v.eject() # wrapper is discarded and original ViT instance is returned
|
||||
```
|
||||
|
||||
## Research Ideas
|
||||
@@ -299,12 +579,89 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```bibtex
|
||||
@misc{yuan2021tokenstotoken,
|
||||
title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
|
||||
author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
|
||||
year = {2021},
|
||||
eprint = {2101.11986},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
|
||||
author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
|
||||
year = {2021},
|
||||
eprint = {2101.11986},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{zhou2021deepvit,
|
||||
title = {DeepViT: Towards Deeper Vision Transformer},
|
||||
author = {Daquan Zhou and Bingyi Kang and Xiaojie Jin and Linjie Yang and Xiaochen Lian and Qibin Hou and Jiashi Feng},
|
||||
year = {2021},
|
||||
eprint = {2103.11886},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{touvron2021going,
|
||||
title = {Going deeper with Image Transformers},
|
||||
author = {Hugo Touvron and Matthieu Cord and Alexandre Sablayrolles and Gabriel Synnaeve and Hervé Jégou},
|
||||
year = {2021},
|
||||
eprint = {2103.17239},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{chen2021crossvit,
|
||||
title = {CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification},
|
||||
author = {Chun-Fu Chen and Quanfu Fan and Rameswar Panda},
|
||||
year = {2021},
|
||||
eprint = {2103.14899},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{wu2021cvt,
|
||||
title = {CvT: Introducing Convolutions to Vision Transformers},
|
||||
author = {Haiping Wu and Bin Xiao and Noel Codella and Mengchen Liu and Xiyang Dai and Lu Yuan and Lei Zhang},
|
||||
year = {2021},
|
||||
eprint = {2103.15808},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{heo2021rethinking,
|
||||
title = {Rethinking Spatial Dimensions of Vision Transformers},
|
||||
author = {Byeongho Heo and Sangdoo Yun and Dongyoon Han and Sanghyuk Chun and Junsuk Choe and Seong Joon Oh},
|
||||
year = {2021},
|
||||
eprint = {2103.16302},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{graham2021levit,
|
||||
title = {LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
|
||||
author = {Ben Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Hervé Jégou and Matthijs Douze},
|
||||
year = {2021},
|
||||
eprint = {2104.01136},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{li2021localvit,
|
||||
title = {LocalViT: Bringing Locality to Vision Transformers},
|
||||
author = {Yawei Li and Kai Zhang and Jiezhang Cao and Radu Timofte and Luc Van Gool},
|
||||
year = {2021},
|
||||
eprint = {2104.05707},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
BIN
images/cait.png
Normal file
|
After Width: | Height: | Size: 63 KiB |
BIN
images/cross_vit.png
Normal file
|
After Width: | Height: | Size: 82 KiB |
BIN
images/cvt.png
Normal file
|
After Width: | Height: | Size: 65 KiB |
|
Before Width: | Height: | Size: 49 KiB After Width: | Height: | Size: 49 KiB |
BIN
images/levit.png
Normal file
|
After Width: | Height: | Size: 71 KiB |
BIN
images/pit.png
Normal file
|
After Width: | Height: | Size: 24 KiB |
|
Before Width: | Height: | Size: 109 KiB After Width: | Height: | Size: 109 KiB |
|
Before Width: | Height: | Size: 5.8 MiB After Width: | Height: | Size: 5.8 MiB |
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.7.4',
|
||||
version = '0.16.0',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -1 +1 @@
|
||||
from vit_pytorch.vit_pytorch import ViT
|
||||
from vit_pytorch.vit import ViT
|
||||
|
||||
177
vit_pytorch/cait.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from random import randrange
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def dropout_layers(layers, dropout):
|
||||
if dropout == 0:
|
||||
return layers
|
||||
|
||||
num_layers = len(layers)
|
||||
to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout
|
||||
|
||||
# make sure at least one layer makes it
|
||||
if all(to_drop):
|
||||
rand_index = randrange(num_layers)
|
||||
to_drop[rand_index] = False
|
||||
|
||||
layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
|
||||
return layers
|
||||
|
||||
# classes
|
||||
|
||||
class LayerScale(nn.Module):
|
||||
def __init__(self, dim, fn, depth):
|
||||
super().__init__()
|
||||
if depth <= 18: # epsilon detailed in section 2 of paper
|
||||
init_eps = 0.1
|
||||
elif depth > 18 and depth <= 24:
|
||||
init_eps = 1e-5
|
||||
else:
|
||||
init_eps = 1e-6
|
||||
|
||||
scale = torch.zeros(1, 1, dim).fill_(init_eps)
|
||||
self.scale = nn.Parameter(scale)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(x, **kwargs) * self.scale
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
|
||||
self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context = None):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
|
||||
context = x if not exists(context) else torch.cat((x, context), dim = 1)
|
||||
|
||||
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax
|
||||
attn = self.attend(dots)
|
||||
attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
self.layer_dropout = layer_dropout
|
||||
|
||||
for ind in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
|
||||
LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
|
||||
]))
|
||||
def forward(self, x, context = None):
|
||||
layers = dropout_layers(self.layers, dropout = self.layer_dropout)
|
||||
|
||||
for attn, ff in layers:
|
||||
x = attn(x, context = context) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class CaiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
cls_depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
layer_dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = 3 * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.patch_transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
|
||||
self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
x += self.pos_embedding[:, :n]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.patch_transformer(x)
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = self.cls_transformer(cls_tokens, context = x)
|
||||
|
||||
return self.mlp_head(x[:, 0])
|
||||
270
vit_pytorch/cross_vit.py
Normal file
@@ -0,0 +1,270 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
# pre-layernorm
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
# feedforward
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# attention
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context = None, kv_include_self = False):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
context = default(context, x)
|
||||
|
||||
if kv_include_self:
|
||||
context = torch.cat((x, context), dim = 1) # cross attention requires CLS token includes itself as key / value
|
||||
|
||||
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
# transformer encoder, for small and large patches
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
# projecting CLS tokens, in the case that small and large patch tokens have different dimensions
|
||||
|
||||
class ProjectInOut(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
need_projection = dim_in != dim_out
|
||||
self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity()
|
||||
self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity()
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
x = self.project_in(x)
|
||||
x = self.fn(x, *args, **kwargs)
|
||||
x = self.project_out(x)
|
||||
return x
|
||||
|
||||
# cross attention transformer
|
||||
|
||||
class CrossTransformer(nn.Module):
|
||||
def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
ProjectInOut(sm_dim, lg_dim, PreNorm(lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout))),
|
||||
ProjectInOut(lg_dim, sm_dim, PreNorm(sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout)))
|
||||
]))
|
||||
|
||||
def forward(self, sm_tokens, lg_tokens):
|
||||
(sm_cls, sm_patch_tokens), (lg_cls, lg_patch_tokens) = map(lambda t: (t[:, :1], t[:, 1:]), (sm_tokens, lg_tokens))
|
||||
|
||||
for sm_attend_lg, lg_attend_sm in self.layers:
|
||||
sm_cls = sm_attend_lg(sm_cls, context = lg_patch_tokens, kv_include_self = True) + sm_cls
|
||||
lg_cls = lg_attend_sm(lg_cls, context = sm_patch_tokens, kv_include_self = True) + lg_cls
|
||||
|
||||
sm_tokens = torch.cat((sm_cls, sm_patch_tokens), dim = 1)
|
||||
lg_tokens = torch.cat((lg_cls, lg_patch_tokens), dim = 1)
|
||||
return sm_tokens, lg_tokens
|
||||
|
||||
# multi-scale encoder
|
||||
|
||||
class MultiScaleEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
depth,
|
||||
sm_dim,
|
||||
lg_dim,
|
||||
sm_enc_params,
|
||||
lg_enc_params,
|
||||
cross_attn_heads,
|
||||
cross_attn_depth,
|
||||
cross_attn_dim_head = 64,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Transformer(dim = sm_dim, dropout = dropout, **sm_enc_params),
|
||||
Transformer(dim = lg_dim, dropout = dropout, **lg_enc_params),
|
||||
CrossTransformer(sm_dim = sm_dim, lg_dim = lg_dim, depth = cross_attn_depth, heads = cross_attn_heads, dim_head = cross_attn_dim_head, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, sm_tokens, lg_tokens):
|
||||
for sm_enc, lg_enc, cross_attend in self.layers:
|
||||
sm_tokens, lg_tokens = sm_enc(sm_tokens), lg_enc(lg_tokens)
|
||||
sm_tokens, lg_tokens = cross_attend(sm_tokens, lg_tokens)
|
||||
|
||||
return sm_tokens, lg_tokens
|
||||
|
||||
# patch-based image to token embedder
|
||||
|
||||
class ImageEmbedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
image_size,
|
||||
patch_size,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = 3 * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
|
||||
return self.dropout(x)
|
||||
|
||||
# cross ViT class
|
||||
|
||||
class CrossViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
num_classes,
|
||||
sm_dim,
|
||||
lg_dim,
|
||||
sm_patch_size = 12,
|
||||
sm_enc_depth = 1,
|
||||
sm_enc_heads = 8,
|
||||
sm_enc_mlp_dim = 2048,
|
||||
sm_enc_dim_head = 64,
|
||||
lg_patch_size = 16,
|
||||
lg_enc_depth = 4,
|
||||
lg_enc_heads = 8,
|
||||
lg_enc_mlp_dim = 2048,
|
||||
lg_enc_dim_head = 64,
|
||||
cross_attn_depth = 2,
|
||||
cross_attn_heads = 8,
|
||||
cross_attn_dim_head = 64,
|
||||
depth = 3,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
):
|
||||
super().__init__()
|
||||
self.sm_image_embedder = ImageEmbedder(dim = sm_dim, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
|
||||
self.lg_image_embedder = ImageEmbedder(dim = lg_dim, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)
|
||||
|
||||
self.multi_scale_encoder = MultiScaleEncoder(
|
||||
depth = depth,
|
||||
sm_dim = sm_dim,
|
||||
lg_dim = lg_dim,
|
||||
cross_attn_heads = cross_attn_heads,
|
||||
cross_attn_dim_head = cross_attn_dim_head,
|
||||
cross_attn_depth = cross_attn_depth,
|
||||
sm_enc_params = dict(
|
||||
depth = sm_enc_depth,
|
||||
heads = sm_enc_heads,
|
||||
mlp_dim = sm_enc_mlp_dim,
|
||||
dim_head = sm_enc_dim_head
|
||||
),
|
||||
lg_enc_params = dict(
|
||||
depth = lg_enc_depth,
|
||||
heads = lg_enc_heads,
|
||||
mlp_dim = lg_enc_mlp_dim,
|
||||
dim_head = lg_enc_dim_head
|
||||
),
|
||||
dropout = dropout
|
||||
)
|
||||
|
||||
self.sm_mlp_head = nn.Sequential(nn.LayerNorm(sm_dim), nn.Linear(sm_dim, num_classes))
|
||||
self.lg_mlp_head = nn.Sequential(nn.LayerNorm(lg_dim), nn.Linear(lg_dim, num_classes))
|
||||
|
||||
def forward(self, img):
|
||||
sm_tokens = self.sm_image_embedder(img)
|
||||
lg_tokens = self.lg_image_embedder(img)
|
||||
|
||||
sm_tokens, lg_tokens = self.multi_scale_encoder(sm_tokens, lg_tokens)
|
||||
|
||||
sm_cls, lg_cls = map(lambda t: t[:, 0], (sm_tokens, lg_tokens))
|
||||
|
||||
sm_logits = self.sm_mlp_head(sm_cls)
|
||||
lg_logits = self.lg_mlp_head(lg_cls)
|
||||
|
||||
return sm_logits + lg_logits
|
||||
162
vit_pytorch/cvt.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helper methods
|
||||
|
||||
def group_dict_by_key(cond, d):
|
||||
return_val = [dict(), dict()]
|
||||
for key in d.keys():
|
||||
match = bool(cond(key))
|
||||
ind = int(not match)
|
||||
return_val[ind][key] = d[key]
|
||||
return (*return_val,)
|
||||
|
||||
def group_by_key_prefix_and_remove_prefix(prefix, d):
|
||||
kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
x = rearrange(x, 'b c h w -> b h w c')
|
||||
x = self.norm(x)
|
||||
x = rearrange(x, 'b h w c -> b c h w')
|
||||
return self.fn(x, **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim * mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(dim * mult, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class DepthWiseConv2d(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
|
||||
nn.BatchNorm2d(dim_in),
|
||||
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
padding = proj_kernel // 2
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_q = DepthWiseConv2d(dim, inner_dim, 3, padding = padding, stride = 1, bias = False)
|
||||
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = padding, stride = kv_proj_stride, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
shape = x.shape
|
||||
b, n, _, y, h = *shape, self.heads
|
||||
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
|
||||
|
||||
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class CvT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_classes,
|
||||
s1_emb_dim = 64,
|
||||
s1_emb_kernel = 7,
|
||||
s1_emb_stride = 4,
|
||||
s1_proj_kernel = 3,
|
||||
s1_kv_proj_stride = 2,
|
||||
s1_heads = 1,
|
||||
s1_depth = 1,
|
||||
s1_mlp_mult = 4,
|
||||
s2_emb_dim = 192,
|
||||
s2_emb_kernel = 3,
|
||||
s2_emb_stride = 2,
|
||||
s2_proj_kernel = 3,
|
||||
s2_kv_proj_stride = 2,
|
||||
s2_heads = 3,
|
||||
s2_depth = 2,
|
||||
s2_mlp_mult = 4,
|
||||
s3_emb_dim = 384,
|
||||
s3_emb_kernel = 3,
|
||||
s3_emb_stride = 2,
|
||||
s3_proj_kernel = 3,
|
||||
s3_kv_proj_stride = 2,
|
||||
s3_heads = 4,
|
||||
s3_depth = 10,
|
||||
s3_mlp_mult = 4,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
kwargs = dict(locals())
|
||||
|
||||
dim = 3
|
||||
layers = []
|
||||
|
||||
for prefix in ('s1', 's2', 's3'):
|
||||
config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
|
||||
|
||||
layers.append(nn.Sequential(
|
||||
nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
|
||||
Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)
|
||||
))
|
||||
|
||||
dim = config['emb_dim']
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
*layers,
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
Rearrange('... () () -> ...'),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
@@ -37,35 +37,41 @@ class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
|
||||
|
||||
self.reattn_norm = nn.Sequential(
|
||||
Rearrange('b h i j -> b i j h'),
|
||||
nn.LayerNorm(heads),
|
||||
Rearrange('b i j h -> b h i j')
|
||||
)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, mask = None):
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
# attention
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
mask_value = -torch.finfo(dots.dtype).max
|
||||
|
||||
if mask is not None:
|
||||
mask = F.pad(mask.flatten(1), (1, 0), value = True)
|
||||
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
|
||||
mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')
|
||||
dots.masked_fill_(~mask, mask_value)
|
||||
del mask
|
||||
|
||||
attn = dots.softmax(dim=-1)
|
||||
|
||||
# re-attention
|
||||
|
||||
attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
|
||||
attn = self.reattn_norm(attn)
|
||||
|
||||
# aggregate and out
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
out = self.to_out(out)
|
||||
@@ -80,13 +86,13 @@ class Transformer(nn.Module):
|
||||
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
|
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
|
||||
]))
|
||||
def forward(self, x, mask = None):
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, mask = mask)
|
||||
x = attn(x)
|
||||
x = ff(x)
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
class DeepViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
@@ -113,7 +119,7 @@ class ViT(nn.Module):
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img, mask = None):
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
@@ -122,7 +128,7 @@ class ViT(nn.Module):
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x, mask)
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from vit_pytorch.vit_pytorch import ViT
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.t2t import T2TViT
|
||||
from vit_pytorch.efficient import ViT as EfficientViT
|
||||
|
||||
@@ -15,7 +15,7 @@ def exists(val):
|
||||
# classes
|
||||
|
||||
class DistillMixin:
|
||||
def forward(self, img, distill_token = None, mask = None):
|
||||
def forward(self, img, distill_token = None):
|
||||
distilling = exists(distill_token)
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
@@ -28,7 +28,7 @@ class DistillMixin:
|
||||
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((x, distill_tokens), dim = 1)
|
||||
|
||||
x = self._attend(x, mask)
|
||||
x = self._attend(x)
|
||||
|
||||
if distilling:
|
||||
x, distill_tokens = x[:, :-1], x[:, -1]
|
||||
@@ -56,9 +56,9 @@ class DistillableViT(DistillMixin, ViT):
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x, mask):
|
||||
def _attend(self, x):
|
||||
x = self.dropout(x)
|
||||
x = self.transformer(x, mask)
|
||||
x = self.transformer(x)
|
||||
return x
|
||||
|
||||
class DistillableT2TViT(DistillMixin, T2TViT):
|
||||
@@ -74,7 +74,7 @@ class DistillableT2TViT(DistillMixin, T2TViT):
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x, mask):
|
||||
def _attend(self, x):
|
||||
x = self.dropout(x)
|
||||
x = self.transformer(x)
|
||||
return x
|
||||
@@ -92,7 +92,7 @@ class DistillableEfficientViT(DistillMixin, EfficientViT):
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x, mask):
|
||||
def _attend(self, x):
|
||||
return self.transformer(x)
|
||||
|
||||
# knowledge distillation wrapper
|
||||
@@ -104,7 +104,8 @@ class DistillWrapper(nn.Module):
|
||||
teacher,
|
||||
student,
|
||||
temperature = 1.,
|
||||
alpha = 0.5
|
||||
alpha = 0.5,
|
||||
hard = False
|
||||
):
|
||||
super().__init__()
|
||||
assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
|
||||
@@ -116,6 +117,7 @@ class DistillWrapper(nn.Module):
|
||||
num_classes = student.num_classes
|
||||
self.temperature = temperature
|
||||
self.alpha = alpha
|
||||
self.hard = hard
|
||||
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
|
||||
@@ -137,11 +139,15 @@ class DistillWrapper(nn.Module):
|
||||
|
||||
loss = F.cross_entropy(student_logits, labels)
|
||||
|
||||
distill_loss = F.kl_div(
|
||||
F.log_softmax(distill_logits / T, dim = -1),
|
||||
F.softmax(teacher_logits / T, dim = -1).detach(),
|
||||
reduction = 'batchmean')
|
||||
if not self.hard:
|
||||
distill_loss = F.kl_div(
|
||||
F.log_softmax(distill_logits / T, dim = -1),
|
||||
F.softmax(teacher_logits / T, dim = -1).detach(),
|
||||
reduction = 'batchmean')
|
||||
distill_loss *= T ** 2
|
||||
|
||||
distill_loss *= T ** 2
|
||||
else:
|
||||
teacher_labels = teacher_logits.argmax(dim = -1)
|
||||
distill_loss = F.cross_entropy(student_logits, teacher_labels)
|
||||
|
||||
return loss * alpha + distill_loss * (1 - alpha)
|
||||
|
||||
190
vit_pytorch/levit.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from math import ceil
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def cast_tuple(val, l = 3):
|
||||
val = val if isinstance(val, tuple) else (val,)
|
||||
return (*val, *((val[-1],) * max(l - len(val), 0)))
|
||||
|
||||
def always(val):
|
||||
return lambda *args, **kwargs: val
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim * mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(dim * mult, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False):
|
||||
super().__init__()
|
||||
inner_dim_key = dim_key * heads
|
||||
inner_dim_value = dim_value * heads
|
||||
dim_out = default(dim_out, dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_key ** -0.5
|
||||
|
||||
self.to_q = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, stride = (2 if downsample else 1), bias = False), nn.BatchNorm2d(inner_dim_key))
|
||||
self.to_k = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, bias = False), nn.BatchNorm2d(inner_dim_key))
|
||||
self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.GELU(),
|
||||
nn.Conv2d(inner_dim_value, dim_out, 1),
|
||||
nn.BatchNorm2d(dim_out),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
# positional bias
|
||||
|
||||
self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)
|
||||
|
||||
q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
|
||||
k_range = torch.arange(fmap_size)
|
||||
|
||||
q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
|
||||
k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)
|
||||
|
||||
q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
|
||||
rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
|
||||
|
||||
x_rel, y_rel = rel_pos.unbind(dim = -1)
|
||||
pos_indices = (x_rel * fmap_size) + y_rel
|
||||
|
||||
self.register_buffer('pos_indices', pos_indices)
|
||||
|
||||
def apply_pos_bias(self, fmap):
|
||||
bias = self.pos_bias(self.pos_indices)
|
||||
bias = rearrange(bias, 'i j h -> () h i j')
|
||||
return fmap + bias
|
||||
|
||||
def forward(self, x):
|
||||
b, n, *_, h = *x.shape, self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
y = q.shape[2]
|
||||
|
||||
qkv = (q, self.to_k(x), self.to_v(x))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
dots = self.apply_pos_bias(dots)
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False):
|
||||
super().__init__()
|
||||
dim_out = default(dim_out, dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
self.attn_residual = (not downsample) and dim == dim_out
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out),
|
||||
FeedForward(dim_out, mlp_mult, dropout = dropout)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
attn_res = (x if self.attn_residual else 0)
|
||||
x = attn(x) + attn_res
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class LeViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_mult,
|
||||
stages = 3,
|
||||
dim_key = 32,
|
||||
dim_value = 64,
|
||||
dropout = 0.,
|
||||
num_distill_classes = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
dims = cast_tuple(dim, stages)
|
||||
depths = cast_tuple(depth, stages)
|
||||
layer_heads = cast_tuple(heads, stages)
|
||||
|
||||
assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'
|
||||
|
||||
self.conv_embedding = nn.Sequential(
|
||||
nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
|
||||
nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
|
||||
nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
|
||||
nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1)
|
||||
)
|
||||
|
||||
fmap_size = image_size // (2 ** 4)
|
||||
layers = []
|
||||
|
||||
for ind, dim, depth, heads in zip(range(stages), dims, depths, layer_heads):
|
||||
is_last = ind == (stages - 1)
|
||||
layers.append(Transformer(dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult, dropout))
|
||||
|
||||
if not is_last:
|
||||
next_dim = dims[ind + 1]
|
||||
layers.append(Transformer(dim, fmap_size, 1, heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
|
||||
fmap_size = ceil(fmap_size / 2)
|
||||
|
||||
self.backbone = nn.Sequential(*layers)
|
||||
|
||||
self.pool = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
Rearrange('... () () -> ...')
|
||||
)
|
||||
|
||||
self.distill_head = nn.Linear(dim, num_distill_classes) if exists(num_distill_classes) else always(None)
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.conv_embedding(img)
|
||||
|
||||
x = self.backbone(x)
|
||||
|
||||
x = self.pool(x)
|
||||
|
||||
out = self.mlp_head(x)
|
||||
distill = self.distill_head(x)
|
||||
|
||||
if exists(distill):
|
||||
return out, distill
|
||||
|
||||
return out
|
||||
152
vit_pytorch/local_vit.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from math import sqrt
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# classes
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(x, **kwargs) + x
|
||||
|
||||
class ExcludeCLS(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
cls_token, x = x[:, :1], x[:, 1:]
|
||||
x = self.fn(x, **kwargs)
|
||||
return torch.cat((cls_token, x), dim = 1)
|
||||
|
||||
# prenorm
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
# feed forward related classes
|
||||
|
||||
class DepthWiseConv2d(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
|
||||
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, hidden_dim, 1),
|
||||
nn.Hardswish(),
|
||||
DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
|
||||
nn.Hardswish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(hidden_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
h = w = int(sqrt(x.shape[-2]))
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h = h, w = w)
|
||||
x = self.net(x)
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
return x
|
||||
|
||||
# attention
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
|
||||
ExcludeCLS(Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x)
|
||||
x = ff(x)
|
||||
return x
|
||||
|
||||
# main class
|
||||
|
||||
class LocalViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||
166
vit_pytorch/mpp.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import math
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def prob_mask_like(t, prob):
|
||||
batch, seq_length, _ = t.shape
|
||||
return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob
|
||||
|
||||
|
||||
def get_mask_subset_with_prob(patched_input, prob):
|
||||
batch, seq_len, _, device = *patched_input.shape, patched_input.device
|
||||
max_masked = math.ceil(prob * seq_len)
|
||||
|
||||
rand = torch.rand((batch, seq_len), device=device)
|
||||
_, sampled_indices = rand.topk(max_masked, dim=-1)
|
||||
|
||||
new_mask = torch.zeros((batch, seq_len), device=device)
|
||||
new_mask.scatter_(1, sampled_indices, 1)
|
||||
return new_mask.bool()
|
||||
|
||||
|
||||
# mpp loss
|
||||
|
||||
|
||||
class MPPLoss(nn.Module):
|
||||
def __init__(self, patch_size, channels, output_channel_bits,
|
||||
max_pixel_val):
|
||||
super(MPPLoss, self).__init__()
|
||||
self.patch_size = patch_size
|
||||
self.channels = channels
|
||||
self.output_channel_bits = output_channel_bits
|
||||
self.max_pixel_val = max_pixel_val
|
||||
|
||||
def forward(self, predicted_patches, target, mask):
|
||||
# reshape target to patches
|
||||
p = self.patch_size
|
||||
target = rearrange(target,
|
||||
"b c (h p1) (w p2) -> b (h w) c (p1 p2) ",
|
||||
p1=p,
|
||||
p2=p)
|
||||
|
||||
avg_target = target.mean(dim=3)
|
||||
|
||||
bin_size = self.max_pixel_val / self.output_channel_bits
|
||||
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size)
|
||||
discretized_target = torch.bucketize(avg_target, channel_bins)
|
||||
discretized_target = F.one_hot(discretized_target,
|
||||
self.output_channel_bits)
|
||||
c, bi = self.channels, self.output_channel_bits
|
||||
discretized_target = rearrange(discretized_target,
|
||||
"b n c bi -> b n (c bi)",
|
||||
c=c,
|
||||
bi=bi)
|
||||
|
||||
bin_mask = 2**torch.arange(c * bi - 1, -1,
|
||||
-1).to(discretized_target.device,
|
||||
discretized_target.dtype)
|
||||
target_label = torch.sum(bin_mask * discretized_target, -1)
|
||||
|
||||
predicted_patches = predicted_patches[mask]
|
||||
target_label = target_label[mask]
|
||||
loss = F.cross_entropy(predicted_patches, target_label)
|
||||
return loss
|
||||
|
||||
|
||||
# main class
|
||||
|
||||
|
||||
class MPP(nn.Module):
|
||||
def __init__(self,
|
||||
transformer,
|
||||
patch_size,
|
||||
dim,
|
||||
output_channel_bits=3,
|
||||
channels=3,
|
||||
max_pixel_val=1.0,
|
||||
mask_prob=0.15,
|
||||
replace_prob=0.5,
|
||||
random_patch_prob=0.5):
|
||||
super().__init__()
|
||||
|
||||
self.transformer = transformer
|
||||
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
|
||||
max_pixel_val)
|
||||
|
||||
# output transformation
|
||||
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
|
||||
|
||||
# vit related dimensions
|
||||
self.patch_size = patch_size
|
||||
|
||||
# mpp related probabilities
|
||||
self.mask_prob = mask_prob
|
||||
self.replace_prob = replace_prob
|
||||
self.random_patch_prob = random_patch_prob
|
||||
|
||||
# token ids
|
||||
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
|
||||
|
||||
def forward(self, input, **kwargs):
|
||||
transformer = self.transformer
|
||||
# clone original image for loss
|
||||
img = input.clone().detach()
|
||||
|
||||
# reshape raw image to patches
|
||||
p = self.patch_size
|
||||
input = rearrange(input,
|
||||
'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
|
||||
p1=p,
|
||||
p2=p)
|
||||
|
||||
mask = get_mask_subset_with_prob(input, self.mask_prob)
|
||||
|
||||
# mask input with mask patches with probability of `replace_prob` (keep patches the same with probability 1 - replace_prob)
|
||||
masked_input = input.clone().detach()
|
||||
|
||||
# if random token probability > 0 for mpp
|
||||
if self.random_patch_prob > 0:
|
||||
random_patch_sampling_prob = self.random_patch_prob / (
|
||||
1 - self.replace_prob)
|
||||
random_patch_prob = prob_mask_like(input,
|
||||
random_patch_sampling_prob)
|
||||
bool_random_patch_prob = mask * random_patch_prob == True
|
||||
random_patches = torch.randint(0,
|
||||
input.shape[1],
|
||||
(input.shape[0], input.shape[1]),
|
||||
device=input.device)
|
||||
randomized_input = masked_input[
|
||||
torch.arange(masked_input.shape[0]).unsqueeze(-1),
|
||||
random_patches]
|
||||
masked_input[bool_random_patch_prob] = randomized_input[
|
||||
bool_random_patch_prob]
|
||||
|
||||
# [mask] input
|
||||
replace_prob = prob_mask_like(input, self.replace_prob)
|
||||
bool_mask_replace = (mask * replace_prob) == True
|
||||
masked_input[bool_mask_replace] = self.mask_token
|
||||
|
||||
# linear embedding of patches
|
||||
masked_input = transformer.to_patch_embedding[-1](masked_input)
|
||||
|
||||
# add cls token to input sequence
|
||||
b, n, _ = masked_input.shape
|
||||
cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b)
|
||||
masked_input = torch.cat((cls_tokens, masked_input), dim=1)
|
||||
|
||||
# add positional embeddings to input
|
||||
masked_input += transformer.pos_embedding[:, :(n + 1)]
|
||||
masked_input = transformer.dropout(masked_input)
|
||||
|
||||
# get generator output and get mpp loss
|
||||
masked_input = transformer.transformer(masked_input, **kwargs)
|
||||
cls_logits = self.to_bits(masked_input)
|
||||
logits = cls_logits[:, 1:, :]
|
||||
|
||||
mpp_loss = self.loss(logits, img, mask)
|
||||
|
||||
return mpp_loss
|
||||
180
vit_pytorch/pit.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from math import sqrt
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def cast_tuple(val, num):
|
||||
return val if isinstance(val, tuple) else (val,) * num
|
||||
|
||||
def conv_output_size(image_size, kernel_size, stride, padding = 0):
|
||||
return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
# depthwise convolution, for pooling
|
||||
|
||||
class DepthWiseConv2d(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
|
||||
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# pooling layer
|
||||
|
||||
class Pool(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.downsample = DepthWiseConv2d(dim, dim * 2, kernel_size = 3, stride = 2, padding = 1)
|
||||
self.cls_ff = nn.Linear(dim, dim * 2)
|
||||
|
||||
def forward(self, x):
|
||||
cls_token, tokens = x[:, :1], x[:, 1:]
|
||||
|
||||
cls_token = self.cls_ff(cls_token)
|
||||
|
||||
tokens = rearrange(tokens, 'b (h w) c -> b c h w', h = int(sqrt(tokens.shape[1])))
|
||||
tokens = self.downsample(tokens)
|
||||
tokens = rearrange(tokens, 'b c h w -> b (h w) c')
|
||||
|
||||
return torch.cat((cls_token, tokens), dim = 1)
|
||||
|
||||
# main class
|
||||
|
||||
class PiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing'
|
||||
heads = cast_tuple(heads, len(depth))
|
||||
|
||||
patch_dim = 3 * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
nn.Unfold(kernel_size = patch_size, stride = patch_size // 2),
|
||||
Rearrange('b c n -> b n c'),
|
||||
nn.Linear(patch_dim, dim)
|
||||
)
|
||||
|
||||
output_size = conv_output_size(image_size, patch_size, patch_size // 2)
|
||||
num_patches = output_size ** 2
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
layers = []
|
||||
|
||||
for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
|
||||
not_last = ind < (len(depth) - 1)
|
||||
|
||||
layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))
|
||||
|
||||
if not_last:
|
||||
layers.append(Pool(dim))
|
||||
dim *= 2
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
*layers,
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding
|
||||
x = self.dropout(x)
|
||||
|
||||
return self.layers(x)
|
||||
54
vit_pytorch/recorder.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from functools import wraps
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vit_pytorch.vit import Attention
|
||||
|
||||
def find_modules(nn_module, type):
|
||||
return [module for module in nn_module.modules() if isinstance(module, type)]
|
||||
|
||||
class Recorder(nn.Module):
|
||||
def __init__(self, vit):
|
||||
super().__init__()
|
||||
self.vit = vit
|
||||
|
||||
self.data = None
|
||||
self.recordings = []
|
||||
self.hooks = []
|
||||
self.hook_registered = False
|
||||
self.ejected = False
|
||||
|
||||
def _hook(self, _, input, output):
|
||||
self.recordings.append(output.clone().detach())
|
||||
|
||||
def _register_hook(self):
|
||||
modules = find_modules(self.vit.transformer, Attention)
|
||||
for module in modules:
|
||||
handle = module.attend.register_forward_hook(self._hook)
|
||||
self.hooks.append(handle)
|
||||
self.hook_registered = True
|
||||
|
||||
def eject(self):
|
||||
self.ejected = True
|
||||
for hook in self.hooks:
|
||||
hook.remove()
|
||||
self.hooks.clear()
|
||||
return self.vit
|
||||
|
||||
def clear(self):
|
||||
self.recordings.clear()
|
||||
|
||||
def record(self, attn):
|
||||
recording = attn.clone().detach()
|
||||
self.recordings.append(recording)
|
||||
|
||||
def forward(self, img):
|
||||
assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
|
||||
self.clear()
|
||||
|
||||
if not self.hook_registered:
|
||||
self._register_hook()
|
||||
|
||||
pred = self.vit(img)
|
||||
attns = torch.stack(self.recordings, dim = 1)
|
||||
return pred, attns
|
||||
163
vit_pytorch/rvt.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from math import sqrt, pi, log
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
def rotate_every_two(x):
|
||||
x = rearrange(x, '... (d j) -> ... d j', j = 2)
|
||||
x1, x2 = x.unbind(dim = -1)
|
||||
x = torch.stack((-x2, x1), dim = -1)
|
||||
return rearrange(x, '... d j -> ... (d j)')
|
||||
|
||||
class AxialRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_freq = 10):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
scales = torch.logspace(1., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
|
||||
self.register_buffer('scales', scales)
|
||||
|
||||
def forward(self, x):
|
||||
device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))
|
||||
|
||||
seq = torch.linspace(-1., 1., steps = n, device = device)
|
||||
seq = seq.unsqueeze(-1)
|
||||
|
||||
scales = self.scales[(*((None,) * (len(seq.shape) - 1)), Ellipsis)]
|
||||
scales = scales.to(x)
|
||||
|
||||
seq = seq * scales * pi
|
||||
|
||||
x_sinu = repeat(seq, 'i d -> i j d', j = n)
|
||||
y_sinu = repeat(seq, 'j d -> i j d', i = n)
|
||||
|
||||
sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)
|
||||
cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)
|
||||
|
||||
sin, cos = map(lambda t: rearrange(t, 'i j d -> (i j) d'), (sin, cos))
|
||||
sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos))
|
||||
return sin, cos
|
||||
|
||||
# helper classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
def forward(self, x):
|
||||
x, gates = x.chunk(2, dim = -1)
|
||||
return F.gelu(gates) * x
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim * 2),
|
||||
GEGLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, pos_emb):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
|
||||
|
||||
# apply 2d rotary embeddings to queries and keys, excluding CLS tokens
|
||||
|
||||
sin, cos = pos_emb
|
||||
(q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))
|
||||
q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
|
||||
|
||||
# concat back the CLS tokens
|
||||
|
||||
q = torch.cat((q_cls, q), dim = 1)
|
||||
k = torch.cat((k_cls, k), dim = 1)
|
||||
|
||||
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
self.pos_emb = AxialRotaryEmbedding(dim_head)
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
pos_emb = self.pos_emb(x[:, 1:])
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, pos_emb = pos_emb) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
# Rotary Vision Transformer
|
||||
|
||||
class RvT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||
@@ -2,16 +2,21 @@ import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vit_pytorch.vit_pytorch import Transformer
|
||||
from vit_pytorch.vit import Transformer
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# classes
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def conv_output_size(image_size, kernel_size, stride, padding):
|
||||
return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
|
||||
|
||||
# classes
|
||||
|
||||
class RearrangeImage(nn.Module):
|
||||
def forward(self, x):
|
||||
return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1])))
|
||||
@@ -19,8 +24,7 @@ class RearrangeImage(nn.Module):
|
||||
# main class
|
||||
|
||||
class T2TViT(nn.Module):
|
||||
def __init__(
|
||||
self, *, image_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
|
||||
def __init__(self, *, image_size, num_classes, dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))):
|
||||
super().__init__()
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
@@ -47,7 +51,11 @@ class T2TViT(nn.Module):
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
if not exists(transformer):
|
||||
assert all([exists(depth), exists(heads), exists(mlp_dim)]), 'depth, heads, and mlp_dim must be supplied'
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
else:
|
||||
self.transformer = transformer
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
115
vit_pytorch/vit.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||