Compare commits
67 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60ad4e266e | ||
|
|
a254a0258a | ||
|
|
26df10c0b7 | ||
|
|
17cb8976df | ||
|
|
daf3abbeb5 | ||
|
|
b483b16833 | ||
|
|
c457573808 | ||
|
|
e75b6d0251 | ||
|
|
679e5be3e7 | ||
|
|
7333979e6b | ||
|
|
74b402377b | ||
|
|
41d2d460d0 | ||
|
|
04f86dee3c | ||
|
|
6549522629 | ||
|
|
6a80a4ef89 | ||
|
|
9f05587a7d | ||
|
|
65bb350e85 | ||
|
|
fd4a7dfcf8 | ||
|
|
6f3a5fcf0b | ||
|
|
7807f24509 | ||
|
|
a612327126 | ||
|
|
30a1335d31 | ||
|
|
ab781f7ddb | ||
|
|
4f3dbd003f | ||
|
|
60b5687a79 | ||
|
|
0df1505662 | ||
|
|
3df6c31c61 | ||
|
|
54af220930 | ||
|
|
bad4b94e7b | ||
|
|
fbced01fe7 | ||
|
|
e42e9876bc | ||
|
|
566365978d | ||
|
|
34f78294d3 | ||
|
|
4c29328363 | ||
|
|
27ac10c1f1 | ||
|
|
fa216c45ea | ||
|
|
1d8b7826bf | ||
|
|
53b3af05f6 | ||
|
|
6289619e3f | ||
|
|
b42fa7862e | ||
|
|
dc6622c05c | ||
|
|
30b37c4028 | ||
|
|
4497f1e90f | ||
|
|
b50d3e1334 | ||
|
|
e075460937 | ||
|
|
5e23e48e4d | ||
|
|
db04c0f319 | ||
|
|
0f31ca79e3 | ||
|
|
2cb6b35030 | ||
|
|
2ec9161a98 | ||
|
|
3a3038c702 | ||
|
|
b1f1044c8e | ||
|
|
deb96201d5 | ||
|
|
05b47cc070 | ||
|
|
9ef8da4759 | ||
|
|
506fcf83a6 | ||
|
|
6fb360a1ff | ||
|
|
9332b9e8c9 | ||
|
|
da950e6d2c | ||
|
|
4b9a02d89c | ||
|
|
518924eac5 | ||
|
|
e712003dfb | ||
|
|
d04ce06a30 | ||
|
|
8135d70e4e | ||
|
|
3067155cea | ||
|
|
ab7315cca1 | ||
|
|
15294c304e |
487
README.md
@@ -1,4 +1,4 @@
|
||||
<img src="./vit.gif" width="500px"></img>
|
||||
<img src="./images/vit.gif" width="500px"></img>
|
||||
|
||||
## Vision Transformer - Pytorch
|
||||
|
||||
@@ -33,12 +33,12 @@ v = ViT(
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to attend to
|
||||
|
||||
preds = v(img, mask = mask) # (1, 1000)
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- `image_size`: int.
|
||||
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
|
||||
- `patch_size`: int.
|
||||
@@ -64,7 +64,7 @@ Embedding dropout rate.
|
||||
|
||||
## Distillation
|
||||
|
||||
<img src="./distill.png" width="300px"></img>
|
||||
<img src="./images/distill.png" width="300px"></img>
|
||||
|
||||
A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that use of a distillation token for distilling knowledge from convolutional nets to vision transformer can yield small and efficient vision transformers. This repository offers the means to do distillation easily.
|
||||
|
||||
@@ -94,7 +94,8 @@ distiller = DistillWrapper(
|
||||
student = v,
|
||||
teacher = teacher,
|
||||
temperature = 3, # temperature of distillation
|
||||
alpha = 0.5 # trade between main loss and distillation loss
|
||||
alpha = 0.5, # trade between main loss and distillation loss
|
||||
hard = False # whether to use soft or hard distillation
|
||||
)
|
||||
|
||||
img = torch.randn(2, 3, 256, 256)
|
||||
@@ -144,9 +145,40 @@ img = torch.randn(1, 3, 256, 256)
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## CaiT
|
||||
|
||||
<a href="https://arxiv.org/abs/2103.17239">This paper</a> also notes difficulty in training vision transformers at greater depths and proposes two solutions. First it proposes to do per-channel multiplication of the output of the residual block. Second, it proposes to have the patches attend to one another, and only allow the CLS token to attend to the patches in the last few layers.
|
||||
|
||||
They also add <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a>, noting improvements
|
||||
|
||||
You can use this scheme as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.cait import CaiT
|
||||
|
||||
v = CaiT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 12, # depth of transformer for patch to patch attention only
|
||||
cls_depth = 2, # depth of cross attention of CLS tokens to patch
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1,
|
||||
layer_dropout = 0.05 # randomly dropout 5% of the layers
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Token-to-Token ViT
|
||||
|
||||
<img src="./t2t.png" width="400px"></img>
|
||||
<img src="./images/t2t.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2101.11986">This paper</a> proposes that the first couple layers should downsample the image sequence by unfolding, leading to overlapping image data in each token as shown in the figure above. You can use this variant of the `ViT` as follows.
|
||||
|
||||
@@ -165,7 +197,211 @@ v = T2TViT(
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
v(img) # (1, 1000)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Cross ViT
|
||||
|
||||
<img src="./images/cross_vit.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2103.14899">This paper</a> proposes to have two vision transformers processing the image at different scales, cross attending to one every so often. They show improvements on top of the base vision transformer.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.cross_vit import CrossViT
|
||||
|
||||
v = CrossViT(
|
||||
image_size = 256,
|
||||
num_classes = 1000,
|
||||
depth = 4, # number of multi-scale encoding blocks
|
||||
sm_dim = 192, # high res dimension
|
||||
sm_patch_size = 16, # high res patch size (should be smaller than lg_patch_size)
|
||||
sm_enc_depth = 2, # high res depth
|
||||
sm_enc_heads = 8, # high res heads
|
||||
sm_enc_mlp_dim = 2048, # high res feedforward dimension
|
||||
lg_dim = 384, # low res dimension
|
||||
lg_patch_size = 64, # low res patch size
|
||||
lg_enc_depth = 3, # low res depth
|
||||
lg_enc_heads = 8, # low res heads
|
||||
lg_enc_mlp_dim = 2048, # low res feedforward dimensions
|
||||
cross_attn_depth = 2, # cross attention rounds
|
||||
cross_attn_heads = 8, # cross attention heads
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
pred = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## PiT
|
||||
|
||||
<img src="./images/pit.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2103.16302">This paper</a> proposes to downsample the tokens through a pooling procedure using depth-wise convolutions.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.pit import PiT
|
||||
|
||||
v = PiT(
|
||||
image_size = 224,
|
||||
patch_size = 14,
|
||||
dim = 256,
|
||||
num_classes = 1000,
|
||||
depth = (3, 3, 3), # list of depths, indicating the number of rounds of each stage before a downsample
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
# forward pass now returns predictions and the attention maps
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## LeViT
|
||||
|
||||
<img src="./images/levit.png" width="300px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2104.01136">This paper</a> proposes a number of changes, including (1) convolutional embedding instead of patch-wise projection (2) downsampling in stages (3) extra non-linearity in attention (4) 2d relative positional biases instead of initial absolute positional bias (5) batchnorm in place of layernorm.
|
||||
|
||||
<a href="https://github.com/facebookresearch/LeViT">Official repository</a>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.levit import LeViT
|
||||
|
||||
levit = LeViT(
|
||||
image_size = 224,
|
||||
num_classes = 1000,
|
||||
stages = 3, # number of stages
|
||||
dim = (256, 384, 512), # dimensions at each stage
|
||||
depth = 4, # transformer of depth 4 at each stage
|
||||
heads = (4, 6, 8), # heads at each stage
|
||||
mlp_mult = 2,
|
||||
dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
|
||||
levit(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## CvT
|
||||
|
||||
<img src="./images/cvt.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2103.15808">This paper</a> proposes mixing convolutions and attention. Specifically, convolutions are used to embed and downsample the image / feature map in three stages. Depthwise-convoltion is also used to project the queries, keys, and values for attention.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.cvt import CvT
|
||||
|
||||
v = CvT(
|
||||
num_classes = 1000,
|
||||
s1_emb_dim = 64, # stage 1 - dimension
|
||||
s1_emb_kernel = 7, # stage 1 - conv kernel
|
||||
s1_emb_stride = 4, # stage 1 - conv stride
|
||||
s1_proj_kernel = 3, # stage 1 - attention ds-conv kernel size
|
||||
s1_kv_proj_stride = 2, # stage 1 - attention key / value projection stride
|
||||
s1_heads = 1, # stage 1 - heads
|
||||
s1_depth = 1, # stage 1 - depth
|
||||
s1_mlp_mult = 4, # stage 1 - feedforward expansion factor
|
||||
s2_emb_dim = 192, # stage 2 - (same as above)
|
||||
s2_emb_kernel = 3,
|
||||
s2_emb_stride = 2,
|
||||
s2_proj_kernel = 3,
|
||||
s2_kv_proj_stride = 2,
|
||||
s2_heads = 3,
|
||||
s2_depth = 2,
|
||||
s2_mlp_mult = 4,
|
||||
s3_emb_dim = 384, # stage 3 - (same as above)
|
||||
s3_emb_kernel = 3,
|
||||
s3_emb_stride = 2,
|
||||
s3_proj_kernel = 3,
|
||||
s3_kv_proj_stride = 2,
|
||||
s3_heads = 4,
|
||||
s3_depth = 10,
|
||||
s3_mlp_mult = 4,
|
||||
dropout = 0.
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
|
||||
pred = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Twins SVT
|
||||
|
||||
<img src="./images/twins_svt.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2104.13840">paper</a> proposes mixing local and global attention, along with position encoding generator (proposed in <a href="https://arxiv.org/abs/2102.10882">CPVT</a>) and global average pooling, to achieve the same results as <a href="https://arxiv.org/abs/2103.14030">Swin</a>, without the extra complexity of shifted windows, CLS tokens, nor positional embeddings.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.twins_svt import TwinsSVT
|
||||
|
||||
model = TwinsSVT(
|
||||
num_classes = 1000, # number of output classes
|
||||
s1_emb_dim = 64, # stage 1 - patch embedding projected dimension
|
||||
s1_patch_size = 4, # stage 1 - patch size for patch embedding
|
||||
s1_local_patch_size = 7, # stage 1 - patch size for local attention
|
||||
s1_global_k = 7, # stage 1 - global attention key / value reduction factor, defaults to 7 as specified in paper
|
||||
s1_depth = 1, # stage 1 - number of transformer blocks (local attn -> ff -> global attn -> ff)
|
||||
s2_emb_dim = 128, # stage 2 (same as above)
|
||||
s2_patch_size = 2,
|
||||
s2_local_patch_size = 7,
|
||||
s2_global_k = 7,
|
||||
s2_depth = 1,
|
||||
s3_emb_dim = 256, # stage 3 (same as above)
|
||||
s3_patch_size = 2,
|
||||
s3_local_patch_size = 7,
|
||||
s3_global_k = 7,
|
||||
s3_depth = 5,
|
||||
s4_emb_dim = 512, # stage 4 (same as above)
|
||||
s4_patch_size = 2,
|
||||
s4_local_patch_size = 7,
|
||||
s4_global_k = 7,
|
||||
s4_depth = 4,
|
||||
peg_kernel_size = 3, # positional encoding generator kernel size
|
||||
dropout = 0. # dropout
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
|
||||
pred = model(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## NesT
|
||||
|
||||
<img src="./images/nest.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the heirarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.
|
||||
|
||||
You can use it with the following code (ex. NesT-T)
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.nest import NesT
|
||||
|
||||
nest = NesT(
|
||||
image_size = 224,
|
||||
patch_size = 4,
|
||||
dim = 96,
|
||||
heads = 3,
|
||||
num_hierarchies = 3, # number of hierarchies
|
||||
block_repeats = (8, 4, 1), # the number of transformer blocks at each heirarchy, starting from the bottom
|
||||
num_classes = 1000
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
pred = nest(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Masked Patch Prediction
|
||||
@@ -214,22 +450,17 @@ for _ in range(100):
|
||||
torch.save(model.state_dict(), './pretrained-net.pt')
|
||||
```
|
||||
|
||||
## Research Ideas
|
||||
## Dino
|
||||
|
||||
### Self Supervised Training
|
||||
<img src="./images/dino.png" width="350px"></img>
|
||||
|
||||
You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
|
||||
You can train `ViT` with the recent SOTA self-supervised learning technique, <a href="https://arxiv.org/abs/2104.14294">Dino</a>, with the following code.
|
||||
|
||||
(1)
|
||||
```bash
|
||||
$ pip install byol-pytorch
|
||||
```
|
||||
<a href="https://www.youtube.com/watch?v=h3ij3F3cPIk">Yannic Kilcher</a> video
|
||||
|
||||
(2)
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import ViT
|
||||
from byol_pytorch import BYOL
|
||||
from vit_pytorch import ViT, Dino
|
||||
|
||||
model = ViT(
|
||||
image_size = 256,
|
||||
@@ -241,13 +472,22 @@ model = ViT(
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
learner = BYOL(
|
||||
learner = Dino(
|
||||
model,
|
||||
image_size = 256,
|
||||
hidden_layer = 'to_latent'
|
||||
hidden_layer = 'to_latent', # hidden layer name or index, from which to extract the embedding
|
||||
projection_hidden_size = 256, # projector network hidden dimension
|
||||
projection_layers = 4, # number of layers in projection network
|
||||
num_classes_K = 65336, # output logits dimensions (referenced as K in paper)
|
||||
student_temp = 0.9, # student temperature
|
||||
teacher_temp = 0.04, # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
|
||||
local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper
|
||||
global_lower_crop_scale = 0.5, # lower bound for global crop - 0.5 was recommended in the paper
|
||||
moving_average_decay = 0.9, # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
|
||||
center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
|
||||
)
|
||||
|
||||
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
|
||||
opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)
|
||||
|
||||
def sample_unlabelled_images():
|
||||
return torch.randn(20, 3, 256, 256)
|
||||
@@ -258,13 +498,54 @@ for _ in range(100):
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
opt.step()
|
||||
learner.update_moving_average() # update moving average of target encoder
|
||||
learner.update_moving_average() # update moving average of teacher encoder and teacher centers
|
||||
|
||||
# save your improved network
|
||||
torch.save(model.state_dict(), './pretrained-net.pt')
|
||||
```
|
||||
|
||||
A pytorch-lightning script is ready for you to use at the repository link above.
|
||||
## Accessing Attention
|
||||
|
||||
If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
# import Recorder and wrap the ViT
|
||||
|
||||
from vit_pytorch.recorder import Recorder
|
||||
v = Recorder(v)
|
||||
|
||||
# forward pass now returns predictions and the attention maps
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
preds, attns = v(img)
|
||||
|
||||
# there is one extra patch due to the CLS token
|
||||
|
||||
attns # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)
|
||||
```
|
||||
|
||||
to cleanup the class and the hooks once you have collected enough data
|
||||
|
||||
```python
|
||||
v = v.eject() # wrapper is discarded and original ViT instance is returned
|
||||
```
|
||||
|
||||
## Research Ideas
|
||||
|
||||
### Efficient Attention
|
||||
|
||||
@@ -335,6 +616,58 @@ img = torch.randn(1, 3, 224, 224)
|
||||
v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## FAQ
|
||||
|
||||
- How do I pass in non-square images?
|
||||
|
||||
You can already pass in non-square images - you just have to make sure your height and width is less than or equal to the `image_size`, and both divisible by the `patch_size`
|
||||
|
||||
ex.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 128) # <-- not a square
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
- How do I pass in non-square patches?
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import ViT
|
||||
|
||||
v = ViT(
|
||||
num_classes = 1000,
|
||||
image_size = (256, 128), # image size is a tuple of (height, width)
|
||||
patch_size = (32, 16), # patch size is a tuple of (height, width)
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 128)
|
||||
|
||||
preds = v(img)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
|
||||
@@ -392,6 +725,116 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{touvron2021going,
|
||||
title = {Going deeper with Image Transformers},
|
||||
author = {Hugo Touvron and Matthieu Cord and Alexandre Sablayrolles and Gabriel Synnaeve and Hervé Jégou},
|
||||
year = {2021},
|
||||
eprint = {2103.17239},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{chen2021crossvit,
|
||||
title = {CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification},
|
||||
author = {Chun-Fu Chen and Quanfu Fan and Rameswar Panda},
|
||||
year = {2021},
|
||||
eprint = {2103.14899},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{wu2021cvt,
|
||||
title = {CvT: Introducing Convolutions to Vision Transformers},
|
||||
author = {Haiping Wu and Bin Xiao and Noel Codella and Mengchen Liu and Xiyang Dai and Lu Yuan and Lei Zhang},
|
||||
year = {2021},
|
||||
eprint = {2103.15808},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{heo2021rethinking,
|
||||
title = {Rethinking Spatial Dimensions of Vision Transformers},
|
||||
author = {Byeongho Heo and Sangdoo Yun and Dongyoon Han and Sanghyuk Chun and Junsuk Choe and Seong Joon Oh},
|
||||
year = {2021},
|
||||
eprint = {2103.16302},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{graham2021levit,
|
||||
title = {LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
|
||||
author = {Ben Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Hervé Jégou and Matthijs Douze},
|
||||
year = {2021},
|
||||
eprint = {2104.01136},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{li2021localvit,
|
||||
title = {LocalViT: Bringing Locality to Vision Transformers},
|
||||
author = {Yawei Li and Kai Zhang and Jiezhang Cao and Radu Timofte and Luc Van Gool},
|
||||
year = {2021},
|
||||
eprint = {2104.05707},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{chu2021twins,
|
||||
title = {Twins: Revisiting Spatial Attention Design in Vision Transformers},
|
||||
author = {Xiangxiang Chu and Zhi Tian and Yuqing Wang and Bo Zhang and Haibing Ren and Xiaolin Wei and Huaxia Xia and Chunhua Shen},
|
||||
year = {2021},
|
||||
eprint = {2104.13840},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{su2021roformer,
|
||||
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
|
||||
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
|
||||
year = {2021},
|
||||
eprint = {2104.09864},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CL}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{zhang2021aggregating,
|
||||
title = {Aggregating Nested Transformers},
|
||||
author = {Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
|
||||
year = {2021},
|
||||
eprint = {2105.12723},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{caron2021emerging,
|
||||
title = {Emerging Properties in Self-Supervised Vision Transformers},
|
||||
author = {Mathilde Caron and Hugo Touvron and Ishan Misra and Hervé Jégou and Julien Mairal and Piotr Bojanowski and Armand Joulin},
|
||||
year = {2021},
|
||||
eprint = {2104.14294},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
|
||||
BIN
images/cait.png
Normal file
|
After Width: | Height: | Size: 63 KiB |
BIN
images/cross_vit.png
Normal file
|
After Width: | Height: | Size: 82 KiB |
BIN
images/cvt.png
Normal file
|
After Width: | Height: | Size: 65 KiB |
BIN
images/dino.png
Normal file
|
After Width: | Height: | Size: 84 KiB |
|
Before Width: | Height: | Size: 49 KiB After Width: | Height: | Size: 49 KiB |
BIN
images/levit.png
Normal file
|
After Width: | Height: | Size: 71 KiB |
BIN
images/nest.png
Normal file
|
After Width: | Height: | Size: 75 KiB |
BIN
images/pit.png
Normal file
|
After Width: | Height: | Size: 24 KiB |
|
Before Width: | Height: | Size: 109 KiB After Width: | Height: | Size: 109 KiB |
BIN
images/twins_svt.png
Normal file
|
After Width: | Height: | Size: 110 KiB |
|
Before Width: | Height: | Size: 5.8 MiB After Width: | Height: | Size: 5.8 MiB |
5
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.9.1',
|
||||
version = '0.19.4',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
@@ -15,8 +15,9 @@ setup(
|
||||
'image recognition'
|
||||
],
|
||||
install_requires=[
|
||||
'einops>=0.3',
|
||||
'torch>=1.6',
|
||||
'einops>=0.3'
|
||||
'torchvision'
|
||||
],
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.dino import Dino
|
||||
|
||||
177
vit_pytorch/cait.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from random import randrange
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def dropout_layers(layers, dropout):
|
||||
if dropout == 0:
|
||||
return layers
|
||||
|
||||
num_layers = len(layers)
|
||||
to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout
|
||||
|
||||
# make sure at least one layer makes it
|
||||
if all(to_drop):
|
||||
rand_index = randrange(num_layers)
|
||||
to_drop[rand_index] = False
|
||||
|
||||
layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
|
||||
return layers
|
||||
|
||||
# classes
|
||||
|
||||
class LayerScale(nn.Module):
|
||||
def __init__(self, dim, fn, depth):
|
||||
super().__init__()
|
||||
if depth <= 18: # epsilon detailed in section 2 of paper
|
||||
init_eps = 0.1
|
||||
elif depth > 18 and depth <= 24:
|
||||
init_eps = 1e-5
|
||||
else:
|
||||
init_eps = 1e-6
|
||||
|
||||
scale = torch.zeros(1, 1, dim).fill_(init_eps)
|
||||
self.scale = nn.Parameter(scale)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(x, **kwargs) * self.scale
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
|
||||
self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context = None):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
|
||||
context = x if not exists(context) else torch.cat((x, context), dim = 1)
|
||||
|
||||
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax
|
||||
attn = self.attend(dots)
|
||||
attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
self.layer_dropout = layer_dropout
|
||||
|
||||
for ind in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
|
||||
LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
|
||||
]))
|
||||
def forward(self, x, context = None):
|
||||
layers = dropout_layers(self.layers, dropout = self.layer_dropout)
|
||||
|
||||
for attn, ff in layers:
|
||||
x = attn(x, context = context) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class CaiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
cls_depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
layer_dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = 3 * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.patch_transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
|
||||
self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
x += self.pos_embedding[:, :n]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.patch_transformer(x)
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = self.cls_transformer(cls_tokens, context = x)
|
||||
|
||||
return self.mlp_head(x[:, 0])
|
||||
270
vit_pytorch/cross_vit.py
Normal file
@@ -0,0 +1,270 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
# pre-layernorm
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
# feedforward
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# attention
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context = None, kv_include_self = False):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
context = default(context, x)
|
||||
|
||||
if kv_include_self:
|
||||
context = torch.cat((x, context), dim = 1) # cross attention requires CLS token includes itself as key / value
|
||||
|
||||
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
# transformer encoder, for small and large patches
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
# projecting CLS tokens, in the case that small and large patch tokens have different dimensions
|
||||
|
||||
class ProjectInOut(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
need_projection = dim_in != dim_out
|
||||
self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity()
|
||||
self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity()
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
x = self.project_in(x)
|
||||
x = self.fn(x, *args, **kwargs)
|
||||
x = self.project_out(x)
|
||||
return x
|
||||
|
||||
# cross attention transformer
|
||||
|
||||
class CrossTransformer(nn.Module):
|
||||
def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
ProjectInOut(sm_dim, lg_dim, PreNorm(lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout))),
|
||||
ProjectInOut(lg_dim, sm_dim, PreNorm(sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout)))
|
||||
]))
|
||||
|
||||
def forward(self, sm_tokens, lg_tokens):
|
||||
(sm_cls, sm_patch_tokens), (lg_cls, lg_patch_tokens) = map(lambda t: (t[:, :1], t[:, 1:]), (sm_tokens, lg_tokens))
|
||||
|
||||
for sm_attend_lg, lg_attend_sm in self.layers:
|
||||
sm_cls = sm_attend_lg(sm_cls, context = lg_patch_tokens, kv_include_self = True) + sm_cls
|
||||
lg_cls = lg_attend_sm(lg_cls, context = sm_patch_tokens, kv_include_self = True) + lg_cls
|
||||
|
||||
sm_tokens = torch.cat((sm_cls, sm_patch_tokens), dim = 1)
|
||||
lg_tokens = torch.cat((lg_cls, lg_patch_tokens), dim = 1)
|
||||
return sm_tokens, lg_tokens
|
||||
|
||||
# multi-scale encoder
|
||||
|
||||
class MultiScaleEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
depth,
|
||||
sm_dim,
|
||||
lg_dim,
|
||||
sm_enc_params,
|
||||
lg_enc_params,
|
||||
cross_attn_heads,
|
||||
cross_attn_depth,
|
||||
cross_attn_dim_head = 64,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Transformer(dim = sm_dim, dropout = dropout, **sm_enc_params),
|
||||
Transformer(dim = lg_dim, dropout = dropout, **lg_enc_params),
|
||||
CrossTransformer(sm_dim = sm_dim, lg_dim = lg_dim, depth = cross_attn_depth, heads = cross_attn_heads, dim_head = cross_attn_dim_head, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, sm_tokens, lg_tokens):
|
||||
for sm_enc, lg_enc, cross_attend in self.layers:
|
||||
sm_tokens, lg_tokens = sm_enc(sm_tokens), lg_enc(lg_tokens)
|
||||
sm_tokens, lg_tokens = cross_attend(sm_tokens, lg_tokens)
|
||||
|
||||
return sm_tokens, lg_tokens
|
||||
|
||||
# patch-based image to token embedder
|
||||
|
||||
class ImageEmbedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
image_size,
|
||||
patch_size,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = 3 * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
|
||||
return self.dropout(x)
|
||||
|
||||
# cross ViT class
|
||||
|
||||
class CrossViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
num_classes,
|
||||
sm_dim,
|
||||
lg_dim,
|
||||
sm_patch_size = 12,
|
||||
sm_enc_depth = 1,
|
||||
sm_enc_heads = 8,
|
||||
sm_enc_mlp_dim = 2048,
|
||||
sm_enc_dim_head = 64,
|
||||
lg_patch_size = 16,
|
||||
lg_enc_depth = 4,
|
||||
lg_enc_heads = 8,
|
||||
lg_enc_mlp_dim = 2048,
|
||||
lg_enc_dim_head = 64,
|
||||
cross_attn_depth = 2,
|
||||
cross_attn_heads = 8,
|
||||
cross_attn_dim_head = 64,
|
||||
depth = 3,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
):
|
||||
super().__init__()
|
||||
self.sm_image_embedder = ImageEmbedder(dim = sm_dim, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
|
||||
self.lg_image_embedder = ImageEmbedder(dim = lg_dim, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)
|
||||
|
||||
self.multi_scale_encoder = MultiScaleEncoder(
|
||||
depth = depth,
|
||||
sm_dim = sm_dim,
|
||||
lg_dim = lg_dim,
|
||||
cross_attn_heads = cross_attn_heads,
|
||||
cross_attn_dim_head = cross_attn_dim_head,
|
||||
cross_attn_depth = cross_attn_depth,
|
||||
sm_enc_params = dict(
|
||||
depth = sm_enc_depth,
|
||||
heads = sm_enc_heads,
|
||||
mlp_dim = sm_enc_mlp_dim,
|
||||
dim_head = sm_enc_dim_head
|
||||
),
|
||||
lg_enc_params = dict(
|
||||
depth = lg_enc_depth,
|
||||
heads = lg_enc_heads,
|
||||
mlp_dim = lg_enc_mlp_dim,
|
||||
dim_head = lg_enc_dim_head
|
||||
),
|
||||
dropout = dropout
|
||||
)
|
||||
|
||||
self.sm_mlp_head = nn.Sequential(nn.LayerNorm(sm_dim), nn.Linear(sm_dim, num_classes))
|
||||
self.lg_mlp_head = nn.Sequential(nn.LayerNorm(lg_dim), nn.Linear(lg_dim, num_classes))
|
||||
|
||||
def forward(self, img):
|
||||
sm_tokens = self.sm_image_embedder(img)
|
||||
lg_tokens = self.lg_image_embedder(img)
|
||||
|
||||
sm_tokens, lg_tokens = self.multi_scale_encoder(sm_tokens, lg_tokens)
|
||||
|
||||
sm_cls, lg_cls = map(lambda t: t[:, 0], (sm_tokens, lg_tokens))
|
||||
|
||||
sm_logits = self.sm_mlp_head(sm_cls)
|
||||
lg_logits = self.lg_mlp_head(lg_cls)
|
||||
|
||||
return sm_logits + lg_logits
|
||||
173
vit_pytorch/cvt.py
Normal file
@@ -0,0 +1,173 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helper methods
|
||||
|
||||
def group_dict_by_key(cond, d):
|
||||
return_val = [dict(), dict()]
|
||||
for key in d.keys():
|
||||
match = bool(cond(key))
|
||||
ind = int(not match)
|
||||
return_val[ind][key] = d[key]
|
||||
return (*return_val,)
|
||||
|
||||
def group_by_key_prefix_and_remove_prefix(prefix, d):
|
||||
kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
# classes
|
||||
|
||||
class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (std + self.eps) * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
x = self.norm(x)
|
||||
return self.fn(x, **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim * mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(dim * mult, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class DepthWiseConv2d(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
|
||||
nn.BatchNorm2d(dim_in),
|
||||
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
padding = proj_kernel // 2
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
|
||||
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
shape = x.shape
|
||||
b, n, _, y, h = *shape, self.heads
|
||||
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
|
||||
|
||||
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class CvT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_classes,
|
||||
s1_emb_dim = 64,
|
||||
s1_emb_kernel = 7,
|
||||
s1_emb_stride = 4,
|
||||
s1_proj_kernel = 3,
|
||||
s1_kv_proj_stride = 2,
|
||||
s1_heads = 1,
|
||||
s1_depth = 1,
|
||||
s1_mlp_mult = 4,
|
||||
s2_emb_dim = 192,
|
||||
s2_emb_kernel = 3,
|
||||
s2_emb_stride = 2,
|
||||
s2_proj_kernel = 3,
|
||||
s2_kv_proj_stride = 2,
|
||||
s2_heads = 3,
|
||||
s2_depth = 2,
|
||||
s2_mlp_mult = 4,
|
||||
s3_emb_dim = 384,
|
||||
s3_emb_kernel = 3,
|
||||
s3_emb_stride = 2,
|
||||
s3_proj_kernel = 3,
|
||||
s3_kv_proj_stride = 2,
|
||||
s3_heads = 6,
|
||||
s3_depth = 10,
|
||||
s3_mlp_mult = 4,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
kwargs = dict(locals())
|
||||
|
||||
dim = 3
|
||||
layers = []
|
||||
|
||||
for prefix in ('s1', 's2', 's3'):
|
||||
config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
|
||||
|
||||
layers.append(nn.Sequential(
|
||||
nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
|
||||
LayerNorm(config['emb_dim']),
|
||||
Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)
|
||||
))
|
||||
|
||||
dim = config['emb_dim']
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
*layers,
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
Rearrange('... () () -> ...'),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
303
vit_pytorch/dino.py
Normal file
@@ -0,0 +1,303 @@
|
||||
import copy
|
||||
import random
|
||||
from functools import wraps, partial
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torchvision import transforms as T
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, default):
|
||||
return val if exists(val) else default
|
||||
|
||||
def singleton(cache_key):
|
||||
def inner_fn(fn):
|
||||
@wraps(fn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
instance = getattr(self, cache_key)
|
||||
if instance is not None:
|
||||
return instance
|
||||
|
||||
instance = fn(self, *args, **kwargs)
|
||||
setattr(self, cache_key, instance)
|
||||
return instance
|
||||
return wrapper
|
||||
return inner_fn
|
||||
|
||||
def get_module_device(module):
|
||||
return next(module.parameters()).device
|
||||
|
||||
def set_requires_grad(model, val):
|
||||
for p in model.parameters():
|
||||
p.requires_grad = val
|
||||
|
||||
# loss function # (algorithm 1 in the paper)
|
||||
|
||||
def loss_fn(
|
||||
teacher_logits,
|
||||
student_logits,
|
||||
teacher_temp,
|
||||
student_temp,
|
||||
centers,
|
||||
eps = 1e-20
|
||||
):
|
||||
teacher_logits = teacher_logits.detach()
|
||||
student_probs = (student_logits / student_temp).softmax(dim = -1)
|
||||
teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
|
||||
return - (teacher_probs * torch.log(student_probs + eps)).sum(dim = -1).mean()
|
||||
|
||||
# augmentation utils
|
||||
|
||||
class RandomApply(nn.Module):
|
||||
def __init__(self, fn, p):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.p = p
|
||||
|
||||
def forward(self, x):
|
||||
if random.random() > self.p:
|
||||
return x
|
||||
return self.fn(x)
|
||||
|
||||
# exponential moving average
|
||||
|
||||
class EMA():
|
||||
def __init__(self, beta):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
|
||||
def update_average(self, old, new):
|
||||
if old is None:
|
||||
return new
|
||||
return old * self.beta + (1 - self.beta) * new
|
||||
|
||||
def update_moving_average(ema_updater, ma_model, current_model):
|
||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||
old_weight, up_weight = ma_params.data, current_params.data
|
||||
ma_params.data = ema_updater.update_average(old_weight, up_weight)
|
||||
|
||||
# MLP class for projector and predictor
|
||||
|
||||
class L2Norm(nn.Module):
|
||||
def forward(self, x, eps = 1e-6):
|
||||
norm = x.norm(dim = 1, keepdim = True).clamp(min = eps)
|
||||
return x / norm
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
|
||||
super().__init__()
|
||||
|
||||
layers = []
|
||||
dims = (dim, *((hidden_size,) * (num_layers - 1)))
|
||||
|
||||
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
is_last = ind == (len(dims) - 1)
|
||||
|
||||
layers.extend([
|
||||
nn.Linear(layer_dim_in, layer_dim_out),
|
||||
nn.GELU() if not is_last else nn.Identity()
|
||||
])
|
||||
|
||||
self.net = nn.Sequential(
|
||||
*layers,
|
||||
L2Norm(),
|
||||
nn.Linear(hidden_size, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# a wrapper class for the base neural network
|
||||
# will manage the interception of the hidden layer output
|
||||
# and pipe it into the projecter and predictor nets
|
||||
|
||||
class NetWrapper(nn.Module):
|
||||
def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
self.layer = layer
|
||||
|
||||
self.projector = None
|
||||
self.projection_hidden_size = projection_hidden_size
|
||||
self.projection_num_layers = projection_num_layers
|
||||
self.output_dim = output_dim
|
||||
|
||||
self.hidden = {}
|
||||
self.hook_registered = False
|
||||
|
||||
def _find_layer(self):
|
||||
if type(self.layer) == str:
|
||||
modules = dict([*self.net.named_modules()])
|
||||
return modules.get(self.layer, None)
|
||||
elif type(self.layer) == int:
|
||||
children = [*self.net.children()]
|
||||
return children[self.layer]
|
||||
return None
|
||||
|
||||
def _hook(self, _, input, output):
|
||||
device = input[0].device
|
||||
self.hidden[device] = output.flatten(1)
|
||||
|
||||
def _register_hook(self):
|
||||
layer = self._find_layer()
|
||||
assert layer is not None, f'hidden layer ({self.layer}) not found'
|
||||
handle = layer.register_forward_hook(self._hook)
|
||||
self.hook_registered = True
|
||||
|
||||
@singleton('projector')
|
||||
def _get_projector(self, hidden):
|
||||
_, dim = hidden.shape
|
||||
projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
|
||||
return projector.to(hidden)
|
||||
|
||||
def get_embedding(self, x):
|
||||
if self.layer == -1:
|
||||
return self.net(x)
|
||||
|
||||
if not self.hook_registered:
|
||||
self._register_hook()
|
||||
|
||||
self.hidden.clear()
|
||||
_ = self.net(x)
|
||||
hidden = self.hidden[x.device]
|
||||
self.hidden.clear()
|
||||
|
||||
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
|
||||
return hidden
|
||||
|
||||
def forward(self, x, return_projection = True):
|
||||
embed = self.get_embedding(x)
|
||||
if not return_projection:
|
||||
return embed
|
||||
|
||||
projector = self._get_projector(embed)
|
||||
return projector(embed), embed
|
||||
|
||||
# main class
|
||||
|
||||
class Dino(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
net,
|
||||
image_size,
|
||||
hidden_layer = -2,
|
||||
projection_hidden_size = 256,
|
||||
num_classes_K = 65336,
|
||||
projection_layers = 4,
|
||||
student_temp = 0.9,
|
||||
teacher_temp = 0.04,
|
||||
local_upper_crop_scale = 0.4,
|
||||
global_lower_crop_scale = 0.5,
|
||||
moving_average_decay = 0.9,
|
||||
center_moving_average_decay = 0.9,
|
||||
augment_fn = None,
|
||||
augment_fn2 = None
|
||||
):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
|
||||
# default BYOL augmentation
|
||||
|
||||
DEFAULT_AUG = torch.nn.Sequential(
|
||||
RandomApply(
|
||||
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
|
||||
p = 0.3
|
||||
),
|
||||
T.RandomGrayscale(p=0.2),
|
||||
T.RandomHorizontalFlip(),
|
||||
RandomApply(
|
||||
T.GaussianBlur((3, 3), (1.0, 2.0)),
|
||||
p = 0.2
|
||||
),
|
||||
T.Normalize(
|
||||
mean=torch.tensor([0.485, 0.456, 0.406]),
|
||||
std=torch.tensor([0.229, 0.224, 0.225])),
|
||||
)
|
||||
|
||||
self.augment1 = default(augment_fn, DEFAULT_AUG)
|
||||
self.augment2 = default(augment_fn2, DEFAULT_AUG)
|
||||
|
||||
# local and global crops
|
||||
|
||||
self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
|
||||
self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))
|
||||
|
||||
self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)
|
||||
|
||||
self.teacher_encoder = None
|
||||
self.teacher_ema_updater = EMA(moving_average_decay)
|
||||
|
||||
self.register_buffer('teacher_centers', torch.zeros(1, num_classes_K))
|
||||
self.register_buffer('last_teacher_centers', torch.zeros(1, num_classes_K))
|
||||
|
||||
self.teacher_centering_ema_updater = EMA(center_moving_average_decay)
|
||||
|
||||
self.student_temp = student_temp
|
||||
self.teacher_temp = teacher_temp
|
||||
|
||||
# get device of network and make wrapper same device
|
||||
device = get_module_device(net)
|
||||
self.to(device)
|
||||
|
||||
# send a mock image tensor to instantiate singleton parameters
|
||||
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
|
||||
|
||||
@singleton('teacher_encoder')
|
||||
def _get_teacher_encoder(self):
|
||||
teacher_encoder = copy.deepcopy(self.student_encoder)
|
||||
set_requires_grad(teacher_encoder, False)
|
||||
return teacher_encoder
|
||||
|
||||
def reset_moving_average(self):
|
||||
del self.teacher_encoder
|
||||
self.teacher_encoder = None
|
||||
|
||||
def update_moving_average(self):
|
||||
assert self.teacher_encoder is not None, 'target encoder has not been created yet'
|
||||
update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)
|
||||
|
||||
new_teacher_centers = self.teacher_centering_ema_updater.update_average(self.teacher_centers, self.last_teacher_centers)
|
||||
self.teacher_centers.copy_(new_teacher_centers)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_embedding = False,
|
||||
return_projection = True,
|
||||
student_temp = None,
|
||||
teacher_temp = None
|
||||
):
|
||||
if return_embedding:
|
||||
return self.student_encoder(x, return_projection = return_projection)
|
||||
|
||||
image_one, image_two = self.augment1(x), self.augment2(x)
|
||||
|
||||
local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
|
||||
global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
|
||||
|
||||
student_proj_one, _ = self.student_encoder(local_image_one)
|
||||
student_proj_two, _ = self.student_encoder(local_image_two)
|
||||
|
||||
with torch.no_grad():
|
||||
teacher_encoder = self._get_teacher_encoder()
|
||||
teacher_proj_one, _ = teacher_encoder(global_image_one)
|
||||
teacher_proj_two, _ = teacher_encoder(global_image_two)
|
||||
|
||||
loss_fn_ = partial(
|
||||
loss_fn,
|
||||
student_temp = default(student_temp, self.student_temp),
|
||||
teacher_temp = default(teacher_temp, self.teacher_temp),
|
||||
centers = self.teacher_centers
|
||||
)
|
||||
|
||||
teacher_logits_avg = torch.cat((teacher_proj_one, teacher_proj_two)).mean(dim = 0)
|
||||
self.last_teacher_centers.copy_(teacher_logits_avg)
|
||||
|
||||
loss = (loss_fn_(teacher_proj_one, student_proj_two) + loss_fn_(teacher_proj_two, student_proj_one)) / 2
|
||||
return loss
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from vit_pytorch.vit_pytorch import ViT
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.t2t import T2TViT
|
||||
from vit_pytorch.efficient import ViT as EfficientViT
|
||||
|
||||
@@ -15,7 +15,7 @@ def exists(val):
|
||||
# classes
|
||||
|
||||
class DistillMixin:
|
||||
def forward(self, img, distill_token = None, mask = None):
|
||||
def forward(self, img, distill_token = None):
|
||||
distilling = exists(distill_token)
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
@@ -28,7 +28,7 @@ class DistillMixin:
|
||||
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((x, distill_tokens), dim = 1)
|
||||
|
||||
x = self._attend(x, mask)
|
||||
x = self._attend(x)
|
||||
|
||||
if distilling:
|
||||
x, distill_tokens = x[:, :-1], x[:, -1]
|
||||
@@ -56,9 +56,9 @@ class DistillableViT(DistillMixin, ViT):
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x, mask):
|
||||
def _attend(self, x):
|
||||
x = self.dropout(x)
|
||||
x = self.transformer(x, mask)
|
||||
x = self.transformer(x)
|
||||
return x
|
||||
|
||||
class DistillableT2TViT(DistillMixin, T2TViT):
|
||||
@@ -74,7 +74,7 @@ class DistillableT2TViT(DistillMixin, T2TViT):
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x, mask):
|
||||
def _attend(self, x):
|
||||
x = self.dropout(x)
|
||||
x = self.transformer(x)
|
||||
return x
|
||||
@@ -92,7 +92,7 @@ class DistillableEfficientViT(DistillMixin, EfficientViT):
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x, mask):
|
||||
def _attend(self, x):
|
||||
return self.transformer(x)
|
||||
|
||||
# knowledge distillation wrapper
|
||||
@@ -104,7 +104,8 @@ class DistillWrapper(nn.Module):
|
||||
teacher,
|
||||
student,
|
||||
temperature = 1.,
|
||||
alpha = 0.5
|
||||
alpha = 0.5,
|
||||
hard = False
|
||||
):
|
||||
super().__init__()
|
||||
assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
|
||||
@@ -116,6 +117,7 @@ class DistillWrapper(nn.Module):
|
||||
num_classes = student.num_classes
|
||||
self.temperature = temperature
|
||||
self.alpha = alpha
|
||||
self.hard = hard
|
||||
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
|
||||
@@ -137,11 +139,15 @@ class DistillWrapper(nn.Module):
|
||||
|
||||
loss = F.cross_entropy(student_logits, labels)
|
||||
|
||||
distill_loss = F.kl_div(
|
||||
F.log_softmax(distill_logits / T, dim = -1),
|
||||
F.softmax(teacher_logits / T, dim = -1).detach(),
|
||||
reduction = 'batchmean')
|
||||
if not self.hard:
|
||||
distill_loss = F.kl_div(
|
||||
F.log_softmax(distill_logits / T, dim = -1),
|
||||
F.softmax(teacher_logits / T, dim = -1).detach(),
|
||||
reduction = 'batchmean')
|
||||
distill_loss *= T ** 2
|
||||
|
||||
distill_loss *= T ** 2
|
||||
else:
|
||||
teacher_labels = teacher_logits.argmax(dim = -1)
|
||||
distill_loss = F.cross_entropy(student_logits, teacher_labels)
|
||||
|
||||
return loss * alpha + distill_loss * (1 - alpha)
|
||||
return loss * (1 - alpha) + distill_loss * alpha
|
||||
|
||||
193
vit_pytorch/levit.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from math import ceil
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def cast_tuple(val, l = 3):
|
||||
val = val if isinstance(val, tuple) else (val,)
|
||||
return (*val, *((val[-1],) * max(l - len(val), 0)))
|
||||
|
||||
def always(val):
|
||||
return lambda *args, **kwargs: val
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim * mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(dim * mult, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False):
|
||||
super().__init__()
|
||||
inner_dim_key = dim_key * heads
|
||||
inner_dim_value = dim_value * heads
|
||||
dim_out = default(dim_out, dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_key ** -0.5
|
||||
|
||||
self.to_q = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, stride = (2 if downsample else 1), bias = False), nn.BatchNorm2d(inner_dim_key))
|
||||
self.to_k = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, bias = False), nn.BatchNorm2d(inner_dim_key))
|
||||
self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
out_batch_norm = nn.BatchNorm2d(dim_out)
|
||||
nn.init.zeros_(out_batch_norm.weight)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.GELU(),
|
||||
nn.Conv2d(inner_dim_value, dim_out, 1),
|
||||
out_batch_norm,
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
# positional bias
|
||||
|
||||
self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)
|
||||
|
||||
q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
|
||||
k_range = torch.arange(fmap_size)
|
||||
|
||||
q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
|
||||
k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)
|
||||
|
||||
q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
|
||||
rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
|
||||
|
||||
x_rel, y_rel = rel_pos.unbind(dim = -1)
|
||||
pos_indices = (x_rel * fmap_size) + y_rel
|
||||
|
||||
self.register_buffer('pos_indices', pos_indices)
|
||||
|
||||
def apply_pos_bias(self, fmap):
|
||||
bias = self.pos_bias(self.pos_indices)
|
||||
bias = rearrange(bias, 'i j h -> () h i j')
|
||||
return fmap + (bias / self.scale)
|
||||
|
||||
def forward(self, x):
|
||||
b, n, *_, h = *x.shape, self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
y = q.shape[2]
|
||||
|
||||
qkv = (q, self.to_k(x), self.to_v(x))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
dots = self.apply_pos_bias(dots)
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False):
|
||||
super().__init__()
|
||||
dim_out = default(dim_out, dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
self.attn_residual = (not downsample) and dim == dim_out
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out),
|
||||
FeedForward(dim_out, mlp_mult, dropout = dropout)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
attn_res = (x if self.attn_residual else 0)
|
||||
x = attn(x) + attn_res
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class LeViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_mult,
|
||||
stages = 3,
|
||||
dim_key = 32,
|
||||
dim_value = 64,
|
||||
dropout = 0.,
|
||||
num_distill_classes = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
dims = cast_tuple(dim, stages)
|
||||
depths = cast_tuple(depth, stages)
|
||||
layer_heads = cast_tuple(heads, stages)
|
||||
|
||||
assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'
|
||||
|
||||
self.conv_embedding = nn.Sequential(
|
||||
nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
|
||||
nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
|
||||
nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
|
||||
nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1)
|
||||
)
|
||||
|
||||
fmap_size = image_size // (2 ** 4)
|
||||
layers = []
|
||||
|
||||
for ind, dim, depth, heads in zip(range(stages), dims, depths, layer_heads):
|
||||
is_last = ind == (stages - 1)
|
||||
layers.append(Transformer(dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult, dropout))
|
||||
|
||||
if not is_last:
|
||||
next_dim = dims[ind + 1]
|
||||
layers.append(Transformer(dim, fmap_size, 1, heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
|
||||
fmap_size = ceil(fmap_size / 2)
|
||||
|
||||
self.backbone = nn.Sequential(*layers)
|
||||
|
||||
self.pool = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
Rearrange('... () () -> ...')
|
||||
)
|
||||
|
||||
self.distill_head = nn.Linear(dim, num_distill_classes) if exists(num_distill_classes) else always(None)
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.conv_embedding(img)
|
||||
|
||||
x = self.backbone(x)
|
||||
|
||||
x = self.pool(x)
|
||||
|
||||
out = self.mlp_head(x)
|
||||
distill = self.distill_head(x)
|
||||
|
||||
if exists(distill):
|
||||
return out, distill
|
||||
|
||||
return out
|
||||
152
vit_pytorch/local_vit.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from math import sqrt
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# classes
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(x, **kwargs) + x
|
||||
|
||||
class ExcludeCLS(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
cls_token, x = x[:, :1], x[:, 1:]
|
||||
x = self.fn(x, **kwargs)
|
||||
return torch.cat((cls_token, x), dim = 1)
|
||||
|
||||
# prenorm
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
# feed forward related classes
|
||||
|
||||
class DepthWiseConv2d(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
|
||||
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, hidden_dim, 1),
|
||||
nn.Hardswish(),
|
||||
DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
|
||||
nn.Hardswish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(hidden_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
h = w = int(sqrt(x.shape[-2]))
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h = h, w = w)
|
||||
x = self.net(x)
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
return x
|
||||
|
||||
# attention
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
|
||||
ExcludeCLS(Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x)
|
||||
x = ff(x)
|
||||
return x
|
||||
|
||||
# main class
|
||||
|
||||
class LocalViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
return self.mlp_head(x[:, 0])
|
||||
@@ -50,7 +50,7 @@ class MPPLoss(nn.Module):
|
||||
avg_target = target.mean(dim=3)
|
||||
|
||||
bin_size = self.max_pixel_val / self.output_channel_bits
|
||||
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size)
|
||||
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size).to(avg_target.device)
|
||||
discretized_target = torch.bucketize(avg_target, channel_bins)
|
||||
discretized_target = F.one_hot(discretized_target,
|
||||
self.output_channel_bits)
|
||||
@@ -86,7 +86,6 @@ class MPP(nn.Module):
|
||||
replace_prob=0.5,
|
||||
random_patch_prob=0.5):
|
||||
super().__init__()
|
||||
|
||||
self.transformer = transformer
|
||||
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
|
||||
max_pixel_val)
|
||||
@@ -127,8 +126,9 @@ class MPP(nn.Module):
|
||||
random_patch_sampling_prob = self.random_patch_prob / (
|
||||
1 - self.replace_prob)
|
||||
random_patch_prob = prob_mask_like(input,
|
||||
random_patch_sampling_prob)
|
||||
bool_random_patch_prob = mask * random_patch_prob == True
|
||||
random_patch_sampling_prob).to(mask.device)
|
||||
|
||||
bool_random_patch_prob = mask * (random_patch_prob == True)
|
||||
random_patches = torch.randint(0,
|
||||
input.shape[1],
|
||||
(input.shape[0], input.shape[1]),
|
||||
@@ -140,7 +140,7 @@ class MPP(nn.Module):
|
||||
bool_random_patch_prob]
|
||||
|
||||
# [mask] input
|
||||
replace_prob = prob_mask_like(input, self.replace_prob)
|
||||
replace_prob = prob_mask_like(input, self.replace_prob).to(mask.device)
|
||||
bool_mask_replace = (mask * replace_prob) == True
|
||||
masked_input[bool_mask_replace] = self.mask_token
|
||||
|
||||
|
||||
169
vit_pytorch/nest.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from functools import partial
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def cast_tuple(val, depth):
|
||||
return val if isinstance(val, tuple) else ((val,) * depth)
|
||||
|
||||
LayerNorm = partial(nn.InstanceNorm2d, affine = True)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mlp_mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim * mlp_mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(dim * mlp_mult, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dropout = 0.):
|
||||
super().__init__()
|
||||
dim_head = dim // heads
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w, heads = *x.shape, self.heads
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = 1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
|
||||
return self.to_out(out)
|
||||
|
||||
def Aggregate(dim, dim_out):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(dim, dim_out, 3, padding = 1),
|
||||
LayerNorm(dim_out),
|
||||
nn.MaxPool2d(3, stride = 2, padding = 1)
|
||||
)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
self.pos_emb = nn.Parameter(torch.randn(seq_len))
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
*_, h, w = x.shape
|
||||
|
||||
pos_emb = self.pos_emb[:(h * w)]
|
||||
pos_emb = rearrange(pos_emb, '(h w) -> () () h w', h = h, w = w)
|
||||
x = x + pos_emb
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class NesT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
heads,
|
||||
num_hierarchies,
|
||||
block_repeats,
|
||||
mlp_mult = 4,
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
fmap_size = image_size // patch_size
|
||||
blocks = 2 ** (num_hierarchies - 1)
|
||||
|
||||
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy
|
||||
hierarchies = list(reversed(range(num_hierarchies)))
|
||||
mults = [2 ** i for i in hierarchies]
|
||||
|
||||
layer_heads = list(map(lambda t: t * heads, mults))
|
||||
layer_dims = list(map(lambda t: t * dim, mults))
|
||||
|
||||
layer_dims = [*layer_dims, layer_dims[-1]]
|
||||
dim_pairs = zip(layer_dims[:-1], layer_dims[1:])
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
|
||||
nn.Conv2d(patch_dim, layer_dims[0], 1),
|
||||
)
|
||||
|
||||
block_repeats = cast_tuple(block_repeats, num_hierarchies)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats):
|
||||
is_last = level == 0
|
||||
depth = block_repeat
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout),
|
||||
Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
LayerNorm(dim),
|
||||
Reduce('b c h w -> b c', 'mean'),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, c, h, w = x.shape
|
||||
|
||||
num_hierarchies = len(self.layers)
|
||||
|
||||
for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
|
||||
block_size = 2 ** level
|
||||
x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
|
||||
x = transformer(x)
|
||||
x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
|
||||
x = aggregate(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||
183
vit_pytorch/pit.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from math import sqrt
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def cast_tuple(val, num):
|
||||
return val if isinstance(val, tuple) else (val,) * num
|
||||
|
||||
def conv_output_size(image_size, kernel_size, stride, padding = 0):
|
||||
return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
# depthwise convolution, for pooling
|
||||
|
||||
class DepthWiseConv2d(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
|
||||
nn.Conv2d(dim_out, dim_out, kernel_size = 1, bias = bias)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# pooling layer
|
||||
|
||||
class Pool(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.downsample = DepthWiseConv2d(dim, dim * 2, kernel_size = 3, stride = 2, padding = 1)
|
||||
self.cls_ff = nn.Linear(dim, dim * 2)
|
||||
|
||||
def forward(self, x):
|
||||
cls_token, tokens = x[:, :1], x[:, 1:]
|
||||
|
||||
cls_token = self.cls_ff(cls_token)
|
||||
|
||||
tokens = rearrange(tokens, 'b (h w) c -> b c h w', h = int(sqrt(tokens.shape[1])))
|
||||
tokens = self.downsample(tokens)
|
||||
tokens = rearrange(tokens, 'b c h w -> b (h w) c')
|
||||
|
||||
return torch.cat((cls_token, tokens), dim = 1)
|
||||
|
||||
# main class
|
||||
|
||||
class PiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing'
|
||||
heads = cast_tuple(heads, len(depth))
|
||||
|
||||
patch_dim = 3 * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
nn.Unfold(kernel_size = patch_size, stride = patch_size // 2),
|
||||
Rearrange('b c n -> b n c'),
|
||||
nn.Linear(patch_dim, dim)
|
||||
)
|
||||
|
||||
output_size = conv_output_size(image_size, patch_size, patch_size // 2)
|
||||
num_patches = output_size ** 2
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
layers = []
|
||||
|
||||
for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
|
||||
not_last = ind < (len(depth) - 1)
|
||||
|
||||
layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))
|
||||
|
||||
if not_last:
|
||||
layers.append(Pool(dim))
|
||||
dim *= 2
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.layers(x)
|
||||
|
||||
return self.mlp_head(x[:, 0])
|
||||
54
vit_pytorch/recorder.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from functools import wraps
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vit_pytorch.vit import Attention
|
||||
|
||||
def find_modules(nn_module, type):
|
||||
return [module for module in nn_module.modules() if isinstance(module, type)]
|
||||
|
||||
class Recorder(nn.Module):
|
||||
def __init__(self, vit):
|
||||
super().__init__()
|
||||
self.vit = vit
|
||||
|
||||
self.data = None
|
||||
self.recordings = []
|
||||
self.hooks = []
|
||||
self.hook_registered = False
|
||||
self.ejected = False
|
||||
|
||||
def _hook(self, _, input, output):
|
||||
self.recordings.append(output.clone().detach())
|
||||
|
||||
def _register_hook(self):
|
||||
modules = find_modules(self.vit.transformer, Attention)
|
||||
for module in modules:
|
||||
handle = module.attend.register_forward_hook(self._hook)
|
||||
self.hooks.append(handle)
|
||||
self.hook_registered = True
|
||||
|
||||
def eject(self):
|
||||
self.ejected = True
|
||||
for hook in self.hooks:
|
||||
hook.remove()
|
||||
self.hooks.clear()
|
||||
return self.vit
|
||||
|
||||
def clear(self):
|
||||
self.recordings.clear()
|
||||
|
||||
def record(self, attn):
|
||||
recording = attn.clone().detach()
|
||||
self.recordings.append(recording)
|
||||
|
||||
def forward(self, img):
|
||||
assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
|
||||
self.clear()
|
||||
|
||||
if not self.hook_registered:
|
||||
self._register_hook()
|
||||
|
||||
pred = self.vit(img)
|
||||
attns = torch.stack(self.recordings, dim = 1)
|
||||
return pred, attns
|
||||
209
vit_pytorch/rvt.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from math import sqrt, pi, log
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
def rotate_every_two(x):
|
||||
x = rearrange(x, '... (d j) -> ... d j', j = 2)
|
||||
x1, x2 = x.unbind(dim = -1)
|
||||
x = torch.stack((-x2, x1), dim = -1)
|
||||
return rearrange(x, '... d j -> ... (d j)')
|
||||
|
||||
class AxialRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_freq = 10):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
|
||||
self.register_buffer('scales', scales)
|
||||
|
||||
def forward(self, x):
|
||||
device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))
|
||||
|
||||
seq = torch.linspace(-1., 1., steps = n, device = device)
|
||||
seq = seq.unsqueeze(-1)
|
||||
|
||||
scales = self.scales[(*((None,) * (len(seq.shape) - 1)), Ellipsis)]
|
||||
scales = scales.to(x)
|
||||
|
||||
seq = seq * scales * pi
|
||||
|
||||
x_sinu = repeat(seq, 'i d -> i j d', j = n)
|
||||
y_sinu = repeat(seq, 'j d -> i j d', i = n)
|
||||
|
||||
sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)
|
||||
cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)
|
||||
|
||||
sin, cos = map(lambda t: rearrange(t, 'i j d -> (i j) d'), (sin, cos))
|
||||
sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos))
|
||||
return sin, cos
|
||||
|
||||
class DepthWiseConv2d(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
|
||||
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# helper classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class SpatialConv(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, kernel, bias = False):
|
||||
super().__init__()
|
||||
self.conv = DepthWiseConv2d(dim_in, dim_out, kernel, padding = kernel // 2, bias = False)
|
||||
self.cls_proj = nn.Linear(dim_in, dim_out) if dim_in != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, fmap_dims):
|
||||
cls_token, x = x[:, :1], x[:, 1:]
|
||||
x = rearrange(x, 'b (h w) d -> b d h w', **fmap_dims)
|
||||
x = self.conv(x)
|
||||
x = rearrange(x, 'b d h w -> b (h w) d')
|
||||
cls_token = self.cls_proj(cls_token)
|
||||
return torch.cat((cls_token, x), dim = 1)
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
def forward(self, x):
|
||||
x, gates = x.chunk(2, dim = -1)
|
||||
return F.gelu(gates) * x
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0., use_glu = True):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim * 2 if use_glu else hidden_dim),
|
||||
GEGLU() if use_glu else nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = True, use_ds_conv = True, conv_query_kernel = 5):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.use_rotary = use_rotary
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.use_ds_conv = use_ds_conv
|
||||
|
||||
self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False) if use_ds_conv else nn.Linear(dim, inner_dim, bias = False)
|
||||
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, pos_emb, fmap_dims):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
|
||||
to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {}
|
||||
q = self.to_q(x, **to_q_kwargs)
|
||||
|
||||
qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
|
||||
|
||||
if self.use_rotary:
|
||||
# apply 2d rotary embeddings to queries and keys, excluding CLS tokens
|
||||
|
||||
sin, cos = pos_emb
|
||||
dim_rotary = sin.shape[-1]
|
||||
|
||||
(q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))
|
||||
|
||||
# handle the case where rotary dimension < head dimension
|
||||
|
||||
(q, q_pass), (k, k_pass) = map(lambda t: (t[..., :dim_rotary], t[..., dim_rotary:]), (q, k))
|
||||
q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
|
||||
q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
|
||||
|
||||
# concat back the CLS tokens
|
||||
|
||||
q = torch.cat((q_cls, q), dim = 1)
|
||||
k = torch.cat((k_cls, k), dim = 1)
|
||||
|
||||
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
self.pos_emb = AxialRotaryEmbedding(dim_head)
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu))
|
||||
]))
|
||||
def forward(self, x, fmap_dims):
|
||||
pos_emb = self.pos_emb(x[:, 1:])
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, pos_emb = pos_emb, fmap_dims = fmap_dims) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
# Rotary Vision Transformer
|
||||
|
||||
class RvT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary, use_ds_conv, use_glu)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
b, _, h, w, p = *img.shape, self.patch_size
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
n = x.shape[1]
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
fmap_dims = {'h': h // p, 'w': w // p}
|
||||
x = self.transformer(x, fmap_dims = fmap_dims)
|
||||
|
||||
return self.mlp_head(x[:, 0])
|
||||
@@ -2,7 +2,7 @@ import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vit_pytorch.vit_pytorch import Transformer
|
||||
from vit_pytorch.vit import Transformer
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
229
vit_pytorch/twins_svt.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helper methods
|
||||
|
||||
def group_dict_by_key(cond, d):
|
||||
return_val = [dict(), dict()]
|
||||
for key in d.keys():
|
||||
match = bool(cond(key))
|
||||
ind = int(not match)
|
||||
return_val[ind][key] = d[key]
|
||||
return (*return_val,)
|
||||
|
||||
def group_by_key_prefix_and_remove_prefix(prefix, d):
|
||||
kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
# classes
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(x, **kwargs) + x
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (std + self.eps) * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
x = self.norm(x)
|
||||
return self.fn(x, **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim * mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(dim * mult, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class PatchEmbedding(nn.Module):
|
||||
def __init__(self, *, dim, dim_out, patch_size):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_out = dim_out
|
||||
self.patch_size = patch_size
|
||||
self.proj = nn.Conv2d(patch_size ** 2 * dim, dim_out, 1)
|
||||
|
||||
def forward(self, fmap):
|
||||
p = self.patch_size
|
||||
fmap = rearrange(fmap, 'b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = p, p2 = p)
|
||||
return self.proj(fmap)
|
||||
|
||||
class PEG(nn.Module):
|
||||
def __init__(self, dim, kernel_size = 3):
|
||||
super().__init__()
|
||||
self.proj = Residual(nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1))
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(x)
|
||||
|
||||
class LocalAttention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., patch_size = 7):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.patch_size = patch_size
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
|
||||
self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, fmap):
|
||||
shape, p = fmap.shape, self.patch_size
|
||||
b, n, x, y, h = *shape, self.heads
|
||||
x, y = map(lambda t: t // p, (x, y))
|
||||
|
||||
fmap = rearrange(fmap, 'b c (x p1) (y p2) -> (b x y) c p1 p2', p1 = p, p2 = p)
|
||||
|
||||
q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) p1 p2 -> (b h) (p1 p2) d', h = h), (q, k, v))
|
||||
|
||||
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
attn = dots.softmax(dim = - 1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b x y h) (p1 p2) d -> b (h d) (x p1) (y p2)', h = h, x = x, y = y, p1 = p, p2 = p)
|
||||
return self.to_out(out)
|
||||
|
||||
class GlobalAttention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., k = 7):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
|
||||
self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
shape = x.shape
|
||||
b, n, _, y, h = *shape, self.heads
|
||||
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
|
||||
|
||||
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
attn = dots.softmax(dim = -1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads = 8, dim_head = 64, mlp_mult = 4, local_patch_size = 7, global_k = 7, dropout = 0., has_local = True):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Residual(PreNorm(dim, LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size))) if has_local else nn.Identity(),
|
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))) if has_local else nn.Identity(),
|
||||
Residual(PreNorm(dim, GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k))),
|
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout)))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for local_attn, ff1, global_attn, ff2 in self.layers:
|
||||
x = local_attn(x)
|
||||
x = ff1(x)
|
||||
x = global_attn(x)
|
||||
x = ff2(x)
|
||||
return x
|
||||
|
||||
class TwinsSVT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_classes,
|
||||
s1_emb_dim = 64,
|
||||
s1_patch_size = 4,
|
||||
s1_local_patch_size = 7,
|
||||
s1_global_k = 7,
|
||||
s1_depth = 1,
|
||||
s2_emb_dim = 128,
|
||||
s2_patch_size = 2,
|
||||
s2_local_patch_size = 7,
|
||||
s2_global_k = 7,
|
||||
s2_depth = 1,
|
||||
s3_emb_dim = 256,
|
||||
s3_patch_size = 2,
|
||||
s3_local_patch_size = 7,
|
||||
s3_global_k = 7,
|
||||
s3_depth = 5,
|
||||
s4_emb_dim = 512,
|
||||
s4_patch_size = 2,
|
||||
s4_local_patch_size = 7,
|
||||
s4_global_k = 7,
|
||||
s4_depth = 4,
|
||||
peg_kernel_size = 3,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
kwargs = dict(locals())
|
||||
|
||||
dim = 3
|
||||
layers = []
|
||||
|
||||
for prefix in ('s1', 's2', 's3', 's4'):
|
||||
config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
|
||||
is_last = prefix == 's4'
|
||||
|
||||
dim_next = config['emb_dim']
|
||||
|
||||
layers.append(nn.Sequential(
|
||||
PatchEmbedding(dim = dim, dim_out = dim_next, patch_size = config['patch_size']),
|
||||
Transformer(dim = dim_next, depth = 1, local_patch_size = config['local_patch_size'], global_k = config['global_k'], dropout = dropout, has_local = not is_last),
|
||||
PEG(dim = dim_next, kernel_size = peg_kernel_size),
|
||||
Transformer(dim = dim_next, depth = config['depth'], local_patch_size = config['local_patch_size'], global_k = config['global_k'], dropout = dropout, has_local = not is_last)
|
||||
))
|
||||
|
||||
dim = dim_next
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
*layers,
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
Rearrange('... () () -> ...'),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
126
vit_pytorch/vit.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||