Mirror of https://github.com/lucidrains/vit-pytorch.git (synced 2025-12-30 16:12:29 +00:00)
Compare commits
72 Commits
Commit SHA1s in this comparison:

06d375351e, f196d1ec5b, 529044c9b3, c30655f3bc, d2d6de01d3, b9eadaef60, 24ac8350bf, ca3cef9de0, 6e1be11517, 73ed562ce4,
ff863175a6, ca0bdca192, 1c70271778, d7d3febfe3, 946815164a, aeed3381c1, 3f754956fb, 918869571c, e5324242be, 22da26fa4b,
a6c085a2df, 121353c604, 2ece3333da, a73030c9aa, 780f91a220, 88451068e8, 64a2ef6462, 53884f583f, e616b5dcbc, 60ad4e266e,
a254a0258a, 26df10c0b7, 17cb8976df, daf3abbeb5, b483b16833, c457573808, e75b6d0251, 679e5be3e7, 7333979e6b, 74b402377b,
41d2d460d0, 04f86dee3c, 6549522629, 6a80a4ef89, 9f05587a7d, 65bb350e85, fd4a7dfcf8, 6f3a5fcf0b, 7807f24509, a612327126,
30a1335d31, ab781f7ddb, a2df363224, 4f3dbd003f, 710b6d57d3, 60b5687a79, 0df1505662, 3df6c31c61, 54af220930, bad4b94e7b,
fbced01fe7, e42e9876bc, 566365978d, 34f78294d3, 4c29328363, 27ac10c1f1, fa216c45ea, 1d8b7826bf, 53b3af05f6, 6289619e3f,
b42fa7862e, dc6622c05c
README.md (378 changed lines)
@@ -38,6 +38,7 @@ preds = v(img) # (1, 1000)
```

## Parameters

- `image_size`: int.
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height.
- `patch_size`: int.
@@ -61,6 +62,7 @@ Dropout rate.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling
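
The newly documented `pool` argument selects how the final representation is read out. A minimal sketch of using it, assuming the same constructor arguments as the example at the top of the README:

```python
import torch
from vit_pytorch import ViT

# mean-pool the patch tokens instead of reading out the CLS token
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    pool = 'mean'   # 'cls' (default) or 'mean'
)

preds = v(torch.randn(1, 3, 256, 256))  # (1, 1000)
```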

## Distillation

<img src="./images/distill.png" width="300px"></img>
@@ -117,6 +119,7 @@ v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```

## Deep ViT

This <a href="https://arxiv.org/abs/2103.11886">paper</a> notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a> paper from NLP.
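
The Deep ViT usage example is elided from this hunk; the sketch below shows how the `DeepViT` class in `vit_pytorch.deepvit` is typically instantiated, assuming its constructor mirrors `ViT`'s:

```python
import torch
from vit_pytorch.deepvit import DeepViT

v = DeepViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,          # depths past 12 are where Re-attention is meant to help
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)
```
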
@@ -200,6 +203,61 @@ img = torch.randn(1, 3, 224, 224)
preds = v(img) # (1, 1000)
```

## CCT

<img src="https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png" width="400px"></img>

<a href="https://arxiv.org/abs/2104.05704">CCT</a> proposes compact transformers by using convolutions instead of patching and performing sequence pooling. This allows CCT to reach high accuracy with a low number of parameters.

You can use it in two ways:

```python
import torch
from vit_pytorch.cct import CCT

model = CCT(
    img_size=224,
    embedding_dim=384,
    n_conv_layers=2,
    kernel_size=7,
    stride=2,
    padding=3,
    pooling_kernel_size=3,
    pooling_stride=2,
    pooling_padding=1,
    num_layers=14,
    num_heads=6,
    mlp_ratio=3.,
    num_classes=1000,
    positional_embedding='learnable', # ['sine', 'learnable', 'none']
)
```

Alternatively, you can use one of several pre-defined models `[2, 4, 6, 7, 8, 14, 16]`, which pre-define the number of layers, number of attention heads, the mlp ratio, and the embedding dimension.

```python
import torch
from vit_pytorch.cct import cct_14

model = cct_14(
    img_size=224,
    n_conv_layers=1,
    kernel_size=7,
    stride=2,
    padding=3,
    pooling_kernel_size=3,
    pooling_stride=2,
    pooling_padding=1,
    num_classes=1000,
    positional_embedding='learnable', # ['sine', 'learnable', 'none']
)
```

The <a href="https://github.com/SHI-Labs/Compact-Transformers">official repository</a> includes links to pretrained model checkpoints.
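
Note that the pre-defined constructors in `vit_pytorch/cct.py` (added in this diff) also derive convolution defaults from `kernel_size` when `stride` and `padding` are omitted: for `kernel_size=7` the `_cct` helper falls back to `stride = max(1, 7 // 2 - 1) = 2` and `padding = max(1, 7 // 2) = 3`, matching the values spelled out explicitly in the examples above.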

## Cross ViT

<img src="./images/cross_vit.png" width="400px"></img>
@@ -270,6 +328,8 @@ preds = v(img) # (1, 1000)

<a href="https://arxiv.org/abs/2104.01136">This paper</a> proposes a number of changes, including (1) convolutional embedding instead of patch-wise projection, (2) downsampling in stages, (3) extra non-linearity in attention, (4) 2d relative positional biases instead of an initial absolute positional bias, and (5) batchnorm in place of layernorm.

<a href="https://github.com/facebookresearch/LeViT">Official repository</a>

```python
import torch
from vit_pytorch.levit import LeViT
@@ -334,6 +394,100 @@ img = torch.randn(1, 3, 224, 224)
pred = v(img) # (1, 1000)
```

## Twins SVT

<img src="./images/twins_svt.png" width="400px"></img>

This <a href="https://arxiv.org/abs/2104.13840">paper</a> proposes mixing local and global attention, along with a position encoding generator (proposed in <a href="https://arxiv.org/abs/2102.10882">CPVT</a>) and global average pooling, to achieve the same results as <a href="https://arxiv.org/abs/2103.14030">Swin</a>, without the extra complexity of shifted windows, CLS tokens, or positional embeddings.

```python
import torch
from vit_pytorch.twins_svt import TwinsSVT

model = TwinsSVT(
    num_classes = 1000,       # number of output classes
    s1_emb_dim = 64,          # stage 1 - patch embedding projected dimension
    s1_patch_size = 4,        # stage 1 - patch size for patch embedding
    s1_local_patch_size = 7,  # stage 1 - patch size for local attention
    s1_global_k = 7,          # stage 1 - global attention key / value reduction factor, defaults to 7 as specified in paper
    s1_depth = 1,             # stage 1 - number of transformer blocks (local attn -> ff -> global attn -> ff)
    s2_emb_dim = 128,         # stage 2 (same as above)
    s2_patch_size = 2,
    s2_local_patch_size = 7,
    s2_global_k = 7,
    s2_depth = 1,
    s3_emb_dim = 256,         # stage 3 (same as above)
    s3_patch_size = 2,
    s3_local_patch_size = 7,
    s3_global_k = 7,
    s3_depth = 5,
    s4_emb_dim = 512,         # stage 4 (same as above)
    s4_patch_size = 2,
    s4_local_patch_size = 7,
    s4_global_k = 7,
    s4_depth = 4,
    peg_kernel_size = 3,      # positional encoding generator kernel size
    dropout = 0.              # dropout
)

img = torch.randn(1, 3, 224, 224)

pred = model(img) # (1, 1000)
```

## RegionViT

<img src="./images/regionvit.png" width="400px"></img>

<img src="./images/regionvit2.png" width="400px"></img>

<a href="https://arxiv.org/abs/2106.02689">This paper</a> proposes to divide the feature map into local regions, whereby the local tokens attend to each other. Each local region has its own regional token, which then attends to all of its local tokens as well as to the other regional tokens.

You can use it as follows

```python
import torch
from vit_pytorch.regionvit import RegionViT

model = RegionViT(
    dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage
    depth = (2, 2, 8, 2),           # depth of the region-to-local transformer at each stage
    window_size = 7,                # window size, which should be either 7 or 14
    num_classes = 1000,             # number of output classes
    tokenize_local_3_conv = False,  # whether to use a 3-layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
    use_peg = False,                # whether to use the positional generating module. they used this for object detection, for a boost in performance
)

x = torch.randn(1, 3, 224, 224)
logits = model(x) # (1, 1000)
```

## NesT

<img src="./images/nest.png" width="400px"></img>

This <a href="https://arxiv.org/abs/2105.12723">paper</a> processes the image in hierarchical stages, with attention only among the tokens of local blocks, which are aggregated as they move up the hierarchy. The aggregation is done in the image plane, and contains a convolution and a subsequent maxpool, which lets information pass across block boundaries.

You can use it with the following code (e.g. NesT-T)

```python
import torch
from vit_pytorch.nest import NesT

nest = NesT(
    image_size = 224,
    patch_size = 4,
    dim = 96,
    heads = 3,
    num_hierarchies = 3,        # number of hierarchies
    block_repeats = (8, 4, 1),  # the number of transformer blocks at each hierarchy, starting from the bottom
    num_classes = 1000
)

img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```

## Masked Patch Prediction

Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
@@ -367,7 +521,7 @@ mpp_trainer = MPP(
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)
    return torch.FloatTensor(20, 3, 256, 256).uniform_(0., 1.)

for _ in range(100):
    images = sample_unlabelled_images()
@@ -380,6 +534,60 @@ for _ in range(100):
torch.save(model.state_dict(), './pretrained-net.pt')
```
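
The `mpp.py` changes further down in this diff add optional `mean` and `std` arguments to `MPP`, used to un-normalize targets before they are binned. A minimal sketch of wiring them in; the probabilities and the ImageNet statistics shown here are illustrative values, not prescribed by the diff:

```python
import torch
from vit_pytorch import ViT
from vit_pytorch.mpp import MPP

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mpp_trainer = MPP(
    transformer = model,
    patch_size = 32,
    dim = 1024,
    mask_prob = 0.15,              # probability of a patch being used for the masked prediction task
    random_patch_prob = 0.5,       # probability of replacing a selected patch with another random patch
    replace_prob = 0.5,            # probability of replacing a selected patch with the mask token
    mean = (0.485, 0.456, 0.406),  # pass the normalization statistics used by your dataloader
    std = (0.229, 0.224, 0.225)    # so targets can be un-normalized before discretization
)
```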

## Dino

<img src="./images/dino.png" width="350px"></img>

You can train `ViT` with the recent SOTA self-supervised learning technique, <a href="https://arxiv.org/abs/2104.14294">Dino</a>, with the following code.

<a href="https://www.youtube.com/watch?v=h3ij3F3cPIk">Yannic Kilcher</a> video

```python
import torch
from vit_pytorch import ViT, Dino

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

learner = Dino(
    model,
    image_size = 256,
    hidden_layer = 'to_latent',         # hidden layer name or index, from which to extract the embedding
    projection_hidden_size = 256,       # projector network hidden dimension
    projection_layers = 4,              # number of layers in projection network
    num_classes_K = 65336,              # output logits dimensions (referenced as K in paper)
    student_temp = 0.9,                 # student temperature
    teacher_temp = 0.04,                # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
    local_upper_crop_scale = 0.4,       # upper bound for local crop - 0.4 was recommended in the paper
    global_lower_crop_scale = 0.5,      # lower bound for global crop - 0.5 was recommended in the paper
    moving_average_decay = 0.9,         # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
    center_moving_average_decay = 0.9,  # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)

opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of teacher encoder and teacher centers

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')
```

## Accessing Attention

If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below.
@@ -423,56 +631,6 @@ v = v.eject() # wrapper is discarded and original ViT instance is returned
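
The full procedure is elided from this hunk; the sketch below shows the usual `Recorder` workflow (the `eject` call in the context line above is its final step), assuming the wrapper returns `(predictions, attention_weights)`:

```python
import torch
from vit_pytorch import ViT
from vit_pytorch.recorder import Recorder

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

v = Recorder(v)            # wrap the ViT so attention maps are recorded on forward

img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)      # attns roughly (batch, depth, heads, tokens, tokens)

v = v.eject()              # wrapper is discarded and original ViT instance is returned
```
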
## Research Ideas

### Self Supervised Training

You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.

(1)
```bash
$ pip install byol-pytorch
```

(2)
```python
import torch
from vit_pytorch import ViT
from byol_pytorch import BYOL

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

learner = BYOL(
    model,
    image_size = 256,
    hidden_layer = 'to_latent'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')
```

A pytorch-lightning script is ready for you to use at the repository link above.

### Efficient Attention

There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plug in your own sparse attention transformer.
@@ -542,6 +700,58 @@ img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)
```
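
The example itself is elided from this hunk; a minimal sketch of the pattern, assuming the `vit_pytorch.efficient.ViT` wrapper accepts any transformer block via a `transformer` argument (the Nystromformer from the separate `nystrom-attention` package is used purely as an illustration):

```python
import torch
from vit_pytorch.efficient import ViT
from nystrom_attention import Nystromformer  # pip install nystrom-attention

# an O(n) attention transformer to drop into the ViT wrapper
efficient_transformer = Nystromformer(
    dim = 512,
    depth = 12,
    heads = 8,
    num_landmarks = 256
)

v = ViT(
    dim = 512,
    image_size = 2048,
    patch_size = 32,
    num_classes = 1000,
    transformer = efficient_transformer
)

img = torch.randn(1, 3, 2048, 2048)  # a high resolution image
v(img)  # (1, 1000)
```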

## FAQ

- How do I pass in non-square images?

You can already pass in non-square images - you just have to make sure your height and width are less than or equal to the `image_size`, and both divisible by the `patch_size`

ex.

```python
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 128) # <-- not a square

preds = v(img) # (1, 1000)
```

- How do I pass in non-square patches?

```python
import torch
from vit_pytorch import ViT

v = ViT(
    num_classes = 1000,
    image_size = (256, 128),  # image size is a tuple of (height, width)
    patch_size = (32, 16),    # patch size is a tuple of (height, width)
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 128)

preds = v(img)
```

## Resources

Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
@@ -554,6 +764,17 @@ Coming from computer vision and new to transformers? Here are some resources tha

## Citations

```bibtex
@article{hassani2021escaping,
    title   = {Escaping the Big Data Paradigm with Compact Transformers},
    author  = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
    year    = 2021,
    url     = {https://arxiv.org/abs/2104.05704},
    eprint  = {2104.05704},
    archiveprefix = {arXiv},
    primaryclass  = {cs.CV}
}
```

```bibtex
@misc{dosovitskiy2020image,
@@ -665,6 +886,61 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@misc{chu2021twins,
    title   = {Twins: Revisiting Spatial Attention Design in Vision Transformers},
    author  = {Xiangxiang Chu and Zhi Tian and Yuqing Wang and Bo Zhang and Haibing Ren and Xiaolin Wei and Huaxia Xia and Chunhua Shen},
    year    = {2021},
    eprint  = {2104.13840},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}
```

```bibtex
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CL}
}
```

```bibtex
@misc{zhang2021aggregating,
    title   = {Aggregating Nested Transformers},
    author  = {Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
    year    = {2021},
    eprint  = {2105.12723},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}
```

```bibtex
@misc{chen2021regionvit,
    title   = {RegionViT: Regional-to-Local Attention for Vision Transformers},
    author  = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan},
    year    = {2021},
    eprint  = {2106.02689},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}
```

```bibtex
@misc{caron2021emerging,
    title   = {Emerging Properties in Self-Supervised Vision Transformers},
    author  = {Mathilde Caron and Hugo Touvron and Ishan Misra and Hervé Jégou and Julien Mairal and Piotr Bojanowski and Armand Joulin},
    year    = {2021},
    eprint  = {2104.14294},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}
```

```bibtex
@misc{vaswani2017attention,
    title   = {Attention Is All You Need},
@@ -364,9 +364,8 @@
"\n",
"val_transforms = transforms.Compose(\n",
" [\n",
" transforms.Resize((224, 224)),\n",
" transforms.RandomResizedCrop(224),\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.Resize(256),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" ]\n",
")\n",
@@ -374,9 +373,8 @@
"\n",
"test_transforms = transforms.Compose(\n",
" [\n",
" transforms.Resize((224, 224)),\n",
" transforms.RandomResizedCrop(224),\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.Resize(256),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" ]\n",
")\n"
@@ -6250,4 +6248,4 @@
},
"nbformat": 4,
"nbformat_minor": 1
}
}
BIN images/dino.png (new file, binary not shown; 84 KiB)
BIN images/nest.png (new file, binary not shown; 75 KiB)
BIN images/regionvit.png (new file, binary not shown; 94 KiB)
BIN images/regionvit2.png (new file, binary not shown; 55 KiB)
BIN images/twins_svt.png (new file, binary not shown; 110 KiB)
setup.py (5 changed lines)
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'vit-pytorch',
  packages = find_packages(exclude=['examples']),
  version = '0.16.0',
  version = '0.21.0',
  license='MIT',
  description = 'Vision Transformer (ViT) - Pytorch',
  author = 'Phil Wang',
@@ -15,8 +15,9 @@ setup(
    'image recognition'
  ],
  install_requires=[
    'einops>=0.3',
    'torch>=1.6',
    'einops>=0.3'
    'torchvision'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',

@@ -1 +1,2 @@
from vit_pytorch.vit import ViT
from vit_pytorch.dino import Dino
vit_pytorch/cct.py (new file, 339 lines)
@@ -0,0 +1,339 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Pre-defined CCT Models
|
||||
__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']
|
||||
|
||||
|
||||
def cct_2(*args, **kwargs):
|
||||
return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_4(*args, **kwargs):
|
||||
return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_6(*args, **kwargs):
|
||||
return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_7(*args, **kwargs):
|
||||
return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_8(*args, **kwargs):
|
||||
return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_14(*args, **kwargs):
|
||||
return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_16(*args, **kwargs):
|
||||
return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
|
||||
kernel_size=3, stride=None, padding=None,
|
||||
*args, **kwargs):
|
||||
stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
|
||||
padding = padding if padding is not None else max(1, (kernel_size // 2))
|
||||
return CCT(num_layers=num_layers,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
embedding_dim=embedding_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
# Modules
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // self.num_heads
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
self.attn_drop = nn.Dropout(attention_dropout)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(projection_dropout)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
Inspired by torch.nn.TransformerEncoderLayer and
|
||||
rwightman's timm package.
|
||||
"""
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
||||
attention_dropout=0.1, drop_path_rate=0.1):
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
self.pre_norm = nn.LayerNorm(d_model)
|
||||
self.self_attn = Attention(dim=d_model, num_heads=nhead,
|
||||
attention_dropout=attention_dropout, projection_dropout=dropout)
|
||||
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
||||
|
||||
self.activation = F.gelu
|
||||
|
||||
def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
|
||||
src = src + self.drop_path(self.dropout2(src2))
|
||||
return src
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||
"""
|
||||
Obtained from: github.com:rwightman/pytorch-image-models
|
||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
"""
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = x.div(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""
|
||||
Obtained from: github.com:rwightman/pytorch-image-models
|
||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
class Tokenizer(nn.Module):
|
||||
def __init__(self,
|
||||
kernel_size, stride, padding,
|
||||
pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
|
||||
n_conv_layers=1,
|
||||
n_input_channels=3,
|
||||
n_output_channels=64,
|
||||
in_planes=64,
|
||||
activation=None,
|
||||
max_pool=True,
|
||||
conv_bias=False):
|
||||
super(Tokenizer, self).__init__()
|
||||
|
||||
n_filter_list = [n_input_channels] + \
|
||||
[in_planes for _ in range(n_conv_layers - 1)] + \
|
||||
[n_output_channels]
|
||||
|
||||
self.conv_layers = nn.Sequential(
|
||||
*[nn.Sequential(
|
||||
nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
|
||||
kernel_size=(kernel_size, kernel_size),
|
||||
stride=(stride, stride),
|
||||
padding=(padding, padding), bias=conv_bias),
|
||||
nn.Identity() if activation is None else activation(),
|
||||
nn.MaxPool2d(kernel_size=pooling_kernel_size,
|
||||
stride=pooling_stride,
|
||||
padding=pooling_padding) if max_pool else nn.Identity()
|
||||
)
|
||||
for i in range(n_conv_layers)
|
||||
])
|
||||
|
||||
self.flattener = nn.Flatten(2, 3)
|
||||
self.apply(self.init_weight)
|
||||
|
||||
def sequence_length(self, n_channels=3, height=224, width=224):
|
||||
return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
|
||||
|
||||
def forward(self, x):
|
||||
return self.flattener(self.conv_layers(x)).transpose(-2, -1)
|
||||
|
||||
@staticmethod
|
||||
def init_weight(m):
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight)
|
||||
|
||||
|
||||
class TransformerClassifier(nn.Module):
|
||||
def __init__(self,
|
||||
seq_pool=True,
|
||||
embedding_dim=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
num_classes=1000,
|
||||
dropout_rate=0.1,
|
||||
attention_dropout=0.1,
|
||||
stochastic_depth_rate=0.1,
|
||||
positional_embedding='sine',
|
||||
sequence_length=None,
|
||||
*args, **kwargs):
|
||||
super().__init__()
|
||||
positional_embedding = positional_embedding if \
|
||||
positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
|
||||
dim_feedforward = int(embedding_dim * mlp_ratio)
|
||||
self.embedding_dim = embedding_dim
|
||||
self.sequence_length = sequence_length
|
||||
self.seq_pool = seq_pool
|
||||
|
||||
assert sequence_length is not None or positional_embedding == 'none', \
|
||||
f"Positional embedding is set to {positional_embedding} and" \
|
||||
f" the sequence length was not specified."
|
||||
|
||||
if not seq_pool:
|
||||
sequence_length += 1
|
||||
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim),
|
||||
requires_grad=True)
|
||||
else:
|
||||
self.attention_pool = nn.Linear(self.embedding_dim, 1)
|
||||
|
||||
if positional_embedding != 'none':
|
||||
if positional_embedding == 'learnable':
|
||||
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
|
||||
requires_grad=True)
|
||||
nn.init.trunc_normal_(self.positional_emb, std=0.2)
|
||||
else:
|
||||
self.positional_emb = nn.Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
|
||||
requires_grad=False)
|
||||
else:
|
||||
self.positional_emb = None
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
|
||||
dim_feedforward=dim_feedforward, dropout=dropout_rate,
|
||||
attention_dropout=attention_dropout, drop_path_rate=dpr[i])
|
||||
for i in range(num_layers)])
|
||||
self.norm = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.fc = nn.Linear(embedding_dim, num_classes)
|
||||
self.apply(self.init_weight)
|
||||
|
||||
def forward(self, x):
|
||||
if self.positional_emb is None and x.size(1) < self.sequence_length:
|
||||
x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)
|
||||
|
||||
if not self.seq_pool:
|
||||
cls_token = self.class_emb.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
|
||||
if self.positional_emb is not None:
|
||||
x += self.positional_emb
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
x = self.norm(x)
|
||||
|
||||
if self.seq_pool:
|
||||
x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
|
||||
else:
|
||||
x = x[:, 0]
|
||||
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def init_weight(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
@staticmethod
|
||||
def sinusoidal_embedding(n_channels, dim):
|
||||
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
|
||||
for p in range(n_channels)])
|
||||
pe[:, 0::2] = torch.sin(pe[:, 0::2])
|
||||
pe[:, 1::2] = torch.cos(pe[:, 1::2])
|
||||
return pe.unsqueeze(0)
|
||||
|
||||
|
||||
# CCT Main model
|
||||
class CCT(nn.Module):
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
embedding_dim=768,
|
||||
n_input_channels=3,
|
||||
n_conv_layers=1,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
*args, **kwargs):
|
||||
super(CCT, self).__init__()
|
||||
|
||||
self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
|
||||
n_output_channels=embedding_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
pooling_stride=pooling_stride,
|
||||
pooling_padding=pooling_padding,
|
||||
max_pool=True,
|
||||
activation=nn.ReLU,
|
||||
n_conv_layers=n_conv_layers,
|
||||
conv_bias=False)
|
||||
|
||||
self.classifier = TransformerClassifier(
|
||||
sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
|
||||
height=img_size,
|
||||
width=img_size),
|
||||
embedding_dim=embedding_dim,
|
||||
seq_pool=True,
|
||||
dropout_rate=0.,
|
||||
attention_dropout=0.1,
|
||||
stochastic_depth=0.1,
|
||||
*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.tokenizer(x)
|
||||
return self.classifier(x)
|
||||
|
||||
@@ -22,15 +22,25 @@ def group_by_key_prefix_and_remove_prefix(prefix, d):
|
||||
|
||||
# classes
|
||||
|
||||
class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (std + self.eps) * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.norm = LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
x = rearrange(x, 'b c h w -> b h w c')
|
||||
x = self.norm(x)
|
||||
x = rearrange(x, 'b h w c -> b c h w')
|
||||
return self.fn(x, **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
@@ -67,8 +77,8 @@ class Attention(nn.Module):
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_q = DepthWiseConv2d(dim, inner_dim, 3, padding = padding, stride = 1, bias = False)
|
||||
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = padding, stride = kv_proj_stride, bias = False)
|
||||
self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
|
||||
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
@@ -130,7 +140,7 @@ class CvT(nn.Module):
|
||||
s3_emb_stride = 2,
|
||||
s3_proj_kernel = 3,
|
||||
s3_kv_proj_stride = 2,
|
||||
s3_heads = 4,
|
||||
s3_heads = 6,
|
||||
s3_depth = 10,
|
||||
s3_mlp_mult = 4,
|
||||
dropout = 0.
|
||||
@@ -146,6 +156,7 @@ class CvT(nn.Module):
|
||||
|
||||
layers.append(nn.Sequential(
|
||||
nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
|
||||
LayerNorm(config['emb_dim']),
|
||||
Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)
|
||||
))
|
||||
|
||||
|
||||
vit_pytorch/dino.py (new file, 303 lines)
@@ -0,0 +1,303 @@
|
||||
import copy
|
||||
import random
|
||||
from functools import wraps, partial
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torchvision import transforms as T
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, default):
|
||||
return val if exists(val) else default
|
||||
|
||||
def singleton(cache_key):
|
||||
def inner_fn(fn):
|
||||
@wraps(fn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
instance = getattr(self, cache_key)
|
||||
if instance is not None:
|
||||
return instance
|
||||
|
||||
instance = fn(self, *args, **kwargs)
|
||||
setattr(self, cache_key, instance)
|
||||
return instance
|
||||
return wrapper
|
||||
return inner_fn
|
||||
|
||||
def get_module_device(module):
|
||||
return next(module.parameters()).device
|
||||
|
||||
def set_requires_grad(model, val):
|
||||
for p in model.parameters():
|
||||
p.requires_grad = val
|
||||
|
||||
# loss function # (algorithm 1 in the paper)
|
||||
|
||||
def loss_fn(
|
||||
teacher_logits,
|
||||
student_logits,
|
||||
teacher_temp,
|
||||
student_temp,
|
||||
centers,
|
||||
eps = 1e-20
|
||||
):
|
||||
teacher_logits = teacher_logits.detach()
|
||||
student_probs = (student_logits / student_temp).softmax(dim = -1)
|
||||
teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
|
||||
return - (teacher_probs * torch.log(student_probs + eps)).sum(dim = -1).mean()
|
||||
|
||||
# augmentation utils
|
||||
|
||||
class RandomApply(nn.Module):
|
||||
def __init__(self, fn, p):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.p = p
|
||||
|
||||
def forward(self, x):
|
||||
if random.random() > self.p:
|
||||
return x
|
||||
return self.fn(x)
|
||||
|
||||
# exponential moving average
|
||||
|
||||
class EMA():
|
||||
def __init__(self, beta):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
|
||||
def update_average(self, old, new):
|
||||
if old is None:
|
||||
return new
|
||||
return old * self.beta + (1 - self.beta) * new
|
||||
|
||||
def update_moving_average(ema_updater, ma_model, current_model):
|
||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||
old_weight, up_weight = ma_params.data, current_params.data
|
||||
ma_params.data = ema_updater.update_average(old_weight, up_weight)
|
||||
|
||||
# MLP class for projector and predictor
|
||||
|
||||
class L2Norm(nn.Module):
|
||||
def forward(self, x, eps = 1e-6):
|
||||
norm = x.norm(dim = 1, keepdim = True).clamp(min = eps)
|
||||
return x / norm
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
|
||||
super().__init__()
|
||||
|
||||
layers = []
|
||||
dims = (dim, *((hidden_size,) * (num_layers - 1)))
|
||||
|
||||
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
is_last = ind == (len(dims) - 1)
|
||||
|
||||
layers.extend([
|
||||
nn.Linear(layer_dim_in, layer_dim_out),
|
||||
nn.GELU() if not is_last else nn.Identity()
|
||||
])
|
||||
|
||||
self.net = nn.Sequential(
|
||||
*layers,
|
||||
L2Norm(),
|
||||
nn.Linear(hidden_size, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# a wrapper class for the base neural network
|
||||
# will manage the interception of the hidden layer output
|
||||
# and pipe it into the projector and predictor nets
|
||||
|
||||
class NetWrapper(nn.Module):
|
||||
def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
self.layer = layer
|
||||
|
||||
self.projector = None
|
||||
self.projection_hidden_size = projection_hidden_size
|
||||
self.projection_num_layers = projection_num_layers
|
||||
self.output_dim = output_dim
|
||||
|
||||
self.hidden = {}
|
||||
self.hook_registered = False
|
||||
|
||||
def _find_layer(self):
|
||||
if type(self.layer) == str:
|
||||
modules = dict([*self.net.named_modules()])
|
||||
return modules.get(self.layer, None)
|
||||
elif type(self.layer) == int:
|
||||
children = [*self.net.children()]
|
||||
return children[self.layer]
|
||||
return None
|
||||
|
||||
def _hook(self, _, input, output):
|
||||
device = input[0].device
|
||||
self.hidden[device] = output.flatten(1)
|
||||
|
||||
def _register_hook(self):
|
||||
layer = self._find_layer()
|
||||
assert layer is not None, f'hidden layer ({self.layer}) not found'
|
||||
handle = layer.register_forward_hook(self._hook)
|
||||
self.hook_registered = True
|
||||
|
||||
@singleton('projector')
|
||||
def _get_projector(self, hidden):
|
||||
_, dim = hidden.shape
|
||||
projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
|
||||
return projector.to(hidden)
|
||||
|
||||
def get_embedding(self, x):
|
||||
if self.layer == -1:
|
||||
return self.net(x)
|
||||
|
||||
if not self.hook_registered:
|
||||
self._register_hook()
|
||||
|
||||
self.hidden.clear()
|
||||
_ = self.net(x)
|
||||
hidden = self.hidden[x.device]
|
||||
self.hidden.clear()
|
||||
|
||||
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
|
||||
return hidden
|
||||
|
||||
def forward(self, x, return_projection = True):
|
||||
embed = self.get_embedding(x)
|
||||
if not return_projection:
|
||||
return embed
|
||||
|
||||
projector = self._get_projector(embed)
|
||||
return projector(embed), embed
|
||||
|
||||
# main class
|
||||
|
||||
class Dino(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
net,
|
||||
image_size,
|
||||
hidden_layer = -2,
|
||||
projection_hidden_size = 256,
|
||||
num_classes_K = 65336,
|
||||
projection_layers = 4,
|
||||
student_temp = 0.9,
|
||||
teacher_temp = 0.04,
|
||||
local_upper_crop_scale = 0.4,
|
||||
global_lower_crop_scale = 0.5,
|
||||
moving_average_decay = 0.9,
|
||||
center_moving_average_decay = 0.9,
|
||||
augment_fn = None,
|
||||
augment_fn2 = None
|
||||
):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
|
||||
# default BYOL augmentation
|
||||
|
||||
DEFAULT_AUG = torch.nn.Sequential(
|
||||
RandomApply(
|
||||
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
|
||||
p = 0.3
|
||||
),
|
||||
T.RandomGrayscale(p=0.2),
|
||||
T.RandomHorizontalFlip(),
|
||||
RandomApply(
|
||||
T.GaussianBlur((3, 3), (1.0, 2.0)),
|
||||
p = 0.2
|
||||
),
|
||||
T.Normalize(
|
||||
mean=torch.tensor([0.485, 0.456, 0.406]),
|
||||
std=torch.tensor([0.229, 0.224, 0.225])),
|
||||
)
|
||||
|
||||
self.augment1 = default(augment_fn, DEFAULT_AUG)
|
||||
self.augment2 = default(augment_fn2, DEFAULT_AUG)
|
||||
|
||||
# local and global crops
|
||||
|
||||
self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
|
||||
self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))
|
||||
|
||||
self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)
|
||||
|
||||
self.teacher_encoder = None
|
||||
self.teacher_ema_updater = EMA(moving_average_decay)
|
||||
|
||||
self.register_buffer('teacher_centers', torch.zeros(1, num_classes_K))
|
||||
self.register_buffer('last_teacher_centers', torch.zeros(1, num_classes_K))
|
||||
|
||||
self.teacher_centering_ema_updater = EMA(center_moving_average_decay)
|
||||
|
||||
self.student_temp = student_temp
|
||||
self.teacher_temp = teacher_temp
|
||||
|
||||
# get device of network and make wrapper same device
|
||||
device = get_module_device(net)
|
||||
self.to(device)
|
||||
|
||||
# send a mock image tensor to instantiate singleton parameters
|
||||
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
|
||||
|
||||
@singleton('teacher_encoder')
|
||||
def _get_teacher_encoder(self):
|
||||
teacher_encoder = copy.deepcopy(self.student_encoder)
|
||||
set_requires_grad(teacher_encoder, False)
|
||||
return teacher_encoder
|
||||
|
||||
def reset_moving_average(self):
|
||||
del self.teacher_encoder
|
||||
self.teacher_encoder = None
|
||||
|
||||
def update_moving_average(self):
|
||||
assert self.teacher_encoder is not None, 'target encoder has not been created yet'
|
||||
update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)
|
||||
|
||||
new_teacher_centers = self.teacher_centering_ema_updater.update_average(self.teacher_centers, self.last_teacher_centers)
|
||||
self.teacher_centers.copy_(new_teacher_centers)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_embedding = False,
|
||||
return_projection = True,
|
||||
student_temp = None,
|
||||
teacher_temp = None
|
||||
):
|
||||
if return_embedding:
|
||||
return self.student_encoder(x, return_projection = return_projection)
|
||||
|
||||
image_one, image_two = self.augment1(x), self.augment2(x)
|
||||
|
||||
local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
|
||||
global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
|
||||
|
||||
student_proj_one, _ = self.student_encoder(local_image_one)
|
||||
student_proj_two, _ = self.student_encoder(local_image_two)
|
||||
|
||||
with torch.no_grad():
|
||||
teacher_encoder = self._get_teacher_encoder()
|
||||
teacher_proj_one, _ = teacher_encoder(global_image_one)
|
||||
teacher_proj_two, _ = teacher_encoder(global_image_two)
|
||||
|
||||
loss_fn_ = partial(
|
||||
loss_fn,
|
||||
student_temp = default(student_temp, self.student_temp),
|
||||
teacher_temp = default(teacher_temp, self.teacher_temp),
|
||||
centers = self.teacher_centers
|
||||
)
|
||||
|
||||
teacher_logits_avg = torch.cat((teacher_proj_one, teacher_proj_two)).mean(dim = 0)
|
||||
self.last_teacher_centers.copy_(teacher_logits_avg)
|
||||
|
||||
loss = (loss_fn_(teacher_proj_one, student_proj_two) + loss_fn_(teacher_proj_two, student_proj_one)) / 2
|
||||
return loss
|
||||
@@ -148,6 +148,6 @@ class DistillWrapper(nn.Module):
|
||||
|
||||
else:
|
||||
teacher_labels = teacher_logits.argmax(dim = -1)
|
||||
distill_loss = F.cross_entropy(student_logits, teacher_labels)
|
||||
distill_loss = F.cross_entropy(distill_logits, teacher_labels)
|
||||
|
||||
return loss * alpha + distill_loss * (1 - alpha)
|
||||
return loss * (1 - alpha) + distill_loss * alpha
|
||||
|
||||
@@ -29,7 +29,7 @@ class FeedForward(nn.Module):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim * mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Hardswish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(dim * mult, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
@@ -53,10 +53,13 @@ class Attention(nn.Module):
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
out_batch_norm = nn.BatchNorm2d(dim_out)
|
||||
nn.init.zeros_(out_batch_norm.weight)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.GELU(),
|
||||
nn.Conv2d(inner_dim_value, dim_out, 1),
|
||||
nn.BatchNorm2d(dim_out),
|
||||
out_batch_norm,
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
@@ -81,7 +84,7 @@ class Attention(nn.Module):
|
||||
def apply_pos_bias(self, fmap):
|
||||
bias = self.pos_bias(self.pos_indices)
|
||||
bias = rearrange(bias, 'i j h -> () h i j')
|
||||
return fmap + bias
|
||||
return fmap + (bias / self.scale)
|
||||
|
||||
def forward(self, x):
|
||||
b, n, *_, h = *x.shape, self.heads
|
||||
|
||||
@@ -149,4 +149,4 @@ class LocalViT(nn.Module):
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||
return self.mlp_head(x[:, 0])
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
import math
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops import rearrange, repeat, reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def prob_mask_like(t, prob):
|
||||
batch, seq_length, _ = t.shape
|
||||
return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob
|
||||
|
||||
|
||||
def get_mask_subset_with_prob(patched_input, prob):
|
||||
batch, seq_len, _, device = *patched_input.shape, patched_input.device
|
||||
max_masked = math.ceil(prob * seq_len)
|
||||
@@ -31,43 +31,45 @@ def get_mask_subset_with_prob(patched_input, prob):
|
||||
|
||||
|
||||
class MPPLoss(nn.Module):
|
||||
def __init__(self, patch_size, channels, output_channel_bits,
|
||||
max_pixel_val):
|
||||
super(MPPLoss, self).__init__()
|
||||
def __init__(
|
||||
self,
|
||||
patch_size,
|
||||
channels,
|
||||
output_channel_bits,
|
||||
max_pixel_val,
|
||||
mean,
|
||||
std
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.channels = channels
|
||||
self.output_channel_bits = output_channel_bits
|
||||
self.max_pixel_val = max_pixel_val
|
||||
|
||||
self.mean = torch.tensor(mean).view(-1, 1, 1) if mean else None
|
||||
self.std = torch.tensor(std).view(-1, 1, 1) if std else None
|
||||
|
||||
def forward(self, predicted_patches, target, mask):
|
||||
p, c, mpv, bits, device = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits, target.device
|
||||
bin_size = mpv / (2 ** bits)
|
||||
|
||||
# un-normalize input
|
||||
if exists(self.mean) and exists(self.std):
|
||||
target = target * self.std + self.mean
|
||||
|
||||
# reshape target to patches
|
||||
p = self.patch_size
|
||||
target = rearrange(target,
|
||||
"b c (h p1) (w p2) -> b (h w) c (p1 p2) ",
|
||||
p1=p,
|
||||
p2=p)
|
||||
target = target.clamp(max = mpv) # clamp just in case
|
||||
avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1 = p, p2 = p).contiguous()
|
||||
|
||||
avg_target = target.mean(dim=3)
|
||||
|
||||
bin_size = self.max_pixel_val / self.output_channel_bits
|
||||
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size)
|
||||
channel_bins = torch.arange(bin_size, mpv, bin_size, device = device)
|
||||
discretized_target = torch.bucketize(avg_target, channel_bins)
|
||||
discretized_target = F.one_hot(discretized_target,
|
||||
self.output_channel_bits)
|
||||
c, bi = self.channels, self.output_channel_bits
|
||||
discretized_target = rearrange(discretized_target,
|
||||
"b n c bi -> b n (c bi)",
|
||||
c=c,
|
||||
bi=bi)
|
||||
|
||||
bin_mask = 2**torch.arange(c * bi - 1, -1,
|
||||
-1).to(discretized_target.device,
|
||||
discretized_target.dtype)
|
||||
target_label = torch.sum(bin_mask * discretized_target, -1)
|
||||
bin_mask = (2 ** bits) ** torch.arange(0, c, device = device).long()
|
||||
bin_mask = rearrange(bin_mask, 'c -> () () c')
|
||||
|
||||
predicted_patches = predicted_patches[mask]
|
||||
target_label = target_label[mask]
|
||||
loss = F.cross_entropy(predicted_patches, target_label)
|
||||
target_label = torch.sum(bin_mask * discretized_target, dim = -1)
|
||||
|
||||
loss = F.cross_entropy(predicted_patches[mask], target_label[mask])
|
||||
return loss
|
||||
|
||||
|
||||
@@ -75,21 +77,24 @@ class MPPLoss(nn.Module):
|
||||
|
||||
|
||||
class MPP(nn.Module):
|
||||
def __init__(self,
|
||||
transformer,
|
||||
patch_size,
|
||||
dim,
|
||||
output_channel_bits=3,
|
||||
channels=3,
|
||||
max_pixel_val=1.0,
|
||||
mask_prob=0.15,
|
||||
replace_prob=0.5,
|
||||
random_patch_prob=0.5):
|
||||
def __init__(
|
||||
self,
|
||||
transformer,
|
||||
patch_size,
|
||||
dim,
|
||||
output_channel_bits=3,
|
||||
channels=3,
|
||||
max_pixel_val=1.0,
|
||||
mask_prob=0.15,
|
||||
replace_prob=0.5,
|
||||
random_patch_prob=0.5,
|
||||
mean=None,
|
||||
std=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.transformer = transformer
|
||||
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
|
||||
max_pixel_val)
|
||||
max_pixel_val, mean, std)
|
||||
|
||||
# output transformation
|
||||
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
|
||||
@@ -103,7 +108,7 @@ class MPP(nn.Module):
|
||||
self.random_patch_prob = random_patch_prob
|
||||
|
||||
# token ids
|
||||
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
|
||||
self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))
|
||||
|
||||
def forward(self, input, **kwargs):
|
||||
transformer = self.transformer
|
||||
@@ -127,8 +132,9 @@ class MPP(nn.Module):
|
||||
random_patch_sampling_prob = self.random_patch_prob / (
|
||||
1 - self.replace_prob)
|
||||
random_patch_prob = prob_mask_like(input,
|
||||
random_patch_sampling_prob)
|
||||
bool_random_patch_prob = mask * random_patch_prob == True
|
||||
random_patch_sampling_prob).to(mask.device)
|
||||
|
||||
bool_random_patch_prob = mask * (random_patch_prob == True)
|
||||
random_patches = torch.randint(0,
|
||||
input.shape[1],
|
||||
(input.shape[0], input.shape[1]),
|
||||
@@ -140,7 +146,7 @@ class MPP(nn.Module):
|
||||
bool_random_patch_prob]
|
||||
|
||||
# [mask] input
|
||||
replace_prob = prob_mask_like(input, self.replace_prob)
|
||||
replace_prob = prob_mask_like(input, self.replace_prob).to(mask.device)
|
||||
bool_mask_replace = (mask * replace_prob) == True
|
||||
masked_input[bool_mask_replace] = self.mask_token
|
||||
|
||||
|
||||
vit_pytorch/nest.py (new file, 179 lines)
@@ -0,0 +1,179 @@
|
||||
from functools import partial
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def cast_tuple(val, depth):
|
||||
return val if isinstance(val, tuple) else ((val,) * depth)
|
||||
|
||||
# classes
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (std + self.eps) * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, mlp_mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mlp_mult, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mlp_mult, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dropout = 0.):
        super().__init__()
        dim_head = dim // heads
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        b, c, h, w, heads = *x.shape, self.heads

        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)

def Aggregate(dim, dim_out):
    return nn.Sequential(
        nn.Conv2d(dim, dim_out, 3, padding = 1),
        LayerNorm(dim_out),
        nn.MaxPool2d(3, stride = 2, padding = 1)
    )

class Transformer(nn.Module):
    def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.pos_emb = nn.Parameter(torch.randn(seq_len))

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
            ]))
    def forward(self, x):
        *_, h, w = x.shape

        pos_emb = self.pos_emb[:(h * w)]
        pos_emb = rearrange(pos_emb, '(h w) -> () () h w', h = h, w = w)
        x = x + pos_emb

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class NesT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        heads,
        num_hierarchies,
        block_repeats,
        mlp_mult = 4,
        channels = 3,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        fmap_size = image_size // patch_size
        blocks = 2 ** (num_hierarchies - 1)

        seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across hierarchy
        hierarchies = list(reversed(range(num_hierarchies)))
        mults = [2 ** i for i in hierarchies]

        layer_heads = list(map(lambda t: t * heads, mults))
        layer_dims = list(map(lambda t: t * dim, mults))

        layer_dims = [*layer_dims, layer_dims[-1]]
        dim_pairs = zip(layer_dims[:-1], layer_dims[1:])

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
            nn.Conv2d(patch_dim, layer_dims[0], 1),
        )

        block_repeats = cast_tuple(block_repeats, num_hierarchies)

        self.layers = nn.ModuleList([])

        for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats):
            is_last = level == 0
            depth = block_repeat

            self.layers.append(nn.ModuleList([
                Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout),
                Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
            ]))

        self.mlp_head = nn.Sequential(
            LayerNorm(dim),
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, c, h, w = x.shape

        num_hierarchies = len(self.layers)

        for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
            block_size = 2 ** level
            x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
            x = transformer(x)
            x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
            x = aggregate(x)

        return self.mlp_head(x)
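For reference, a minimal usage sketch of the `NesT` class defined above. The hyperparameter values are illustrative, and the `vit_pytorch.nest` import path is assumed from the repository layout:

```python
import torch
from vit_pytorch.nest import NesT

nest = NesT(
    image_size = 224,
    patch_size = 4,
    dim = 96,
    heads = 3,
    num_hierarchies = 3,        # number of hierarchies
    block_repeats = (2, 2, 8),  # number of transformer blocks at each hierarchy, starting from the bottom
    num_classes = 1000
)

img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```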
@@ -89,8 +89,8 @@ class DepthWiseConv2d(nn.Module):
     def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
-            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
+            nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
+            nn.Conv2d(dim_out, dim_out, kernel_size = 1, bias = bias)
         )
     def forward(self, x):
         return self.net(x)
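The change above folds the channel expansion into the grouped convolution, so the 1 x 1 pointwise conv now maps `dim_out` to `dim_out`. A standalone shape check of the revised pattern with illustrative channel counts (note that `dim_out` must stay divisible by `groups = dim_in` for the grouped conv to be valid):

```python
import torch
from torch import nn

# grouped k x k conv expands 32 -> 64 channels, then a 1 x 1 conv mixes the 64 channels
net = nn.Sequential(
    nn.Conv2d(32, 64, kernel_size = 3, padding = 1, groups = 32, stride = 1, bias = True),
    nn.Conv2d(64, 64, kernel_size = 1, bias = True)
)

print(net(torch.randn(1, 32, 56, 56)).shape) # torch.Size([1, 64, 56, 56])
```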
@@ -162,8 +162,9 @@ class PiT(nn.Module):
             layers.append(Pool(dim))
             dim *= 2

-        self.layers = nn.Sequential(
-            *layers,
+        self.layers = nn.Sequential(*layers)
+
+        self.mlp_head = nn.Sequential(
             nn.LayerNorm(dim),
             nn.Linear(dim, num_classes)
         )
@@ -174,7 +175,9 @@ class PiT(nn.Module):

         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding
+        x += self.pos_embedding[:, :n+1]
         x = self.dropout(x)

-        return self.layers(x)
+        x = self.layers(x)
+
+        return self.mlp_head(x[:, 0])
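With the classification head split out of `self.layers`, PiT is invoked the same way as before. A minimal sketch with illustrative hyperparameters (the `vit_pytorch.pit` import path is assumed):

```python
import torch
from vit_pytorch.pit import PiT

v = PiT(
    image_size = 224,
    patch_size = 14,
    dim = 256,
    num_classes = 1000,
    depth = (3, 3, 3),  # number of transformer blocks in each stage before a downsample
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)
preds = v(img) # (1, 1000)
```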
@@ -8,7 +8,7 @@ def find_modules(nn_module, type):
     return [module for module in nn_module.modules() if isinstance(module, type)]

 class Recorder(nn.Module):
-    def __init__(self, vit):
+    def __init__(self, vit, device = None):
         super().__init__()
         self.vit = vit
@@ -17,6 +17,7 @@ class Recorder(nn.Module):
         self.hooks = []
         self.hook_registered = False
         self.ejected = False
+        self.device = device

     def _hook(self, _, input, output):
         self.recordings.append(output.clone().detach())
@@ -45,10 +46,14 @@ class Recorder(nn.Module):
     def forward(self, img):
         assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
         self.clear()

         if not self.hook_registered:
             self._register_hook()

         pred = self.vit(img)
-        attns = torch.stack(self.recordings, dim = 1)
+
+        # move all recordings to one device before stacking
+        target_device = self.device if self.device is not None else img.device
+        recordings = tuple(map(lambda t: t.to(target_device), self.recordings))
+
+        attns = torch.stack(recordings, dim = 1)
         return pred, attns
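A minimal sketch of the `Recorder` wrapper with the new optional `device` argument left at its default, so recordings are moved to the input's device before stacking. ViT hyperparameters are illustrative and the `vit_pytorch.recorder` import path is assumed:

```python
import torch
from vit_pytorch import ViT
from vit_pytorch.recorder import Recorder

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

v = Recorder(v)

img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)
# attns should be (1, 6, 16, 65, 65): batch, depth, heads, patches + cls, patches + cls
```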
vit_pytorch/regionvit.py (new file, 268 lines)
@@ -0,0 +1,268 @@
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
import torch.nn.functional as F

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

def divisible_by(val, d):
    return (val % d) == 0

# helper classes

class Downsample(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)

    def forward(self, x):
        return self.conv(x)

class PEG(nn.Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)

    def forward(self, x):
        return self.proj(x) + x

# transformer classes

def FeedForward(dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult, 1),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim, 1)
    )

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, rel_pos_bias = None):
        h = self.heads

        # prenorm

        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # split heads

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add relative positional bias for local tokens

        if exists(rel_pos_bias):
            sim = sim + rel_pos_bias

        attn = sim.softmax(dim = -1)

        # merge heads

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class R2LTransformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        window_size,
        depth = 4,
        heads = 4,
        dim_head = 32,
        attn_dropout = 0.,
        ff_dropout = 0.,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        self.window_size = window_size
        rel_positions = 2 * window_size - 1
        self.local_rel_pos_bias = nn.Embedding(rel_positions ** 2, heads)

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout)
            ]))

    def forward(self, local_tokens, region_tokens):
        device = local_tokens.device
        lh, lw = local_tokens.shape[-2:]
        rh, rw = region_tokens.shape[-2:]
        window_size_h, window_size_w = lh // rh, lw // rw

        local_tokens = rearrange(local_tokens, 'b c h w -> b (h w) c')
        region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')

        # calculate local relative positional bias

        h_range = torch.arange(window_size_h, device = device)
        w_range = torch.arange(window_size_w, device = device)

        grid_x, grid_y = torch.meshgrid(h_range, w_range)
        grid = torch.stack((grid_x, grid_y))
        grid = rearrange(grid, 'c h w -> c (h w)')
        grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)
        bias_indices = (grid * torch.tensor([1, self.window_size * 2 - 1], device = device)[:, None, None]).sum(dim = 0)
        rel_pos_bias = self.local_rel_pos_bias(bias_indices)
        rel_pos_bias = rearrange(rel_pos_bias, 'i j h -> () h i j')
        rel_pos_bias = F.pad(rel_pos_bias, (1, 0, 1, 0), value = 0)

        # go through r2l transformer layers

        for attn, ff in self.layers:
            region_tokens = attn(region_tokens) + region_tokens

            # concat region tokens to local tokens

            local_tokens = rearrange(local_tokens, 'b (h w) d -> b h w d', h = lh)
            local_tokens = rearrange(local_tokens, 'b (h p1) (w p2) d -> (b h w) (p1 p2) d', p1 = window_size_h, p2 = window_size_w)
            region_tokens = rearrange(region_tokens, 'b n d -> (b n) () d')

            # do self attention on local tokens, along with its regional token

            region_and_local_tokens = torch.cat((region_tokens, local_tokens), dim = 1)
            region_and_local_tokens = attn(region_and_local_tokens, rel_pos_bias = rel_pos_bias) + region_and_local_tokens

            # split back local and regional tokens

            region_tokens, local_tokens = region_and_local_tokens[:, :1], region_and_local_tokens[:, 1:]
            local_tokens = rearrange(local_tokens, '(b h w) (p1 p2) d -> b (h p1 w p2) d', h = lh // window_size_h, w = lw // window_size_w, p1 = window_size_h)
            region_tokens = rearrange(region_tokens, '(b n) () d -> b n d', n = rh * rw)

            # feedforwards

            local_tokens = ff(local_tokens) + local_tokens
            region_tokens = ff(region_tokens) + region_tokens

        local_tokens = rearrange(local_tokens, 'b (h w) c -> b c h w', h = lh, w = lw)
        region_tokens = rearrange(region_tokens, 'b (h w) c -> b c h w', h = rh, w = rw)
        return local_tokens, region_tokens

# classes

class RegionViT(nn.Module):
    def __init__(
        self,
        *,
        dim = (64, 128, 256, 512),
        depth = (2, 2, 8, 2),
        window_size = 7,
        num_classes = 1000,
        tokenize_local_3_conv = False,
        local_patch_size = 4,
        use_peg = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        channels = 3,
    ):
        super().__init__()
        dim = cast_tuple(dim, 4)
        depth = cast_tuple(depth, 4)
        assert len(dim) == 4, 'dim needs to be a single value or a tuple of length 4'
        assert len(depth) == 4, 'depth needs to be a single value or a tuple of length 4'

        self.local_patch_size = local_patch_size

        region_patch_size = local_patch_size * window_size
        self.region_patch_size = local_patch_size * window_size

        init_dim, *_, last_dim = dim

        # local and region encoders

        if tokenize_local_3_conv:
            self.local_encoder = nn.Sequential(
                nn.Conv2d(3, init_dim, 3, 2, 1),
                nn.LayerNorm(init_dim),
                nn.GELU(),
                nn.Conv2d(init_dim, init_dim, 3, 2, 1),
                nn.LayerNorm(init_dim),
                nn.GELU(),
                nn.Conv2d(init_dim, init_dim, 3, 1, 1)
            )
        else:
            self.local_encoder = nn.Conv2d(3, init_dim, 8, 4, 3)

        self.region_encoder = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = region_patch_size, p2 = region_patch_size),
            nn.Conv2d((region_patch_size ** 2) * channels, init_dim, 1)
        )

        # layers

        current_dim = init_dim
        self.layers = nn.ModuleList([])

        for ind, dim, num_layers in zip(range(4), dim, depth):
            not_first = ind != 0
            need_downsample = not_first
            need_peg = not_first and use_peg

            self.layers.append(nn.ModuleList([
                Downsample(current_dim, dim) if need_downsample else nn.Identity(),
                PEG(dim) if need_peg else nn.Identity(),
                R2LTransformer(dim, depth = num_layers, window_size = window_size, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
            ]))

            current_dim = dim

        # final logits

        self.to_logits = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.LayerNorm(last_dim),
            nn.Linear(last_dim, num_classes)
        )

    def forward(
        self,
        x,
        return_local_tokens = False
    ):
        *_, h, w = x.shape
        assert divisible_by(h, self.region_patch_size) and divisible_by(w, self.region_patch_size), 'height and width must be divisible by region patch size'
        assert divisible_by(h, self.local_patch_size) and divisible_by(w, self.local_patch_size), 'height and width must be divisible by local patch size'

        local_tokens = self.local_encoder(x)
        region_tokens = self.region_encoder(x)

        for down, peg, transformer in self.layers:
            local_tokens, region_tokens = down(local_tokens), down(region_tokens)
            local_tokens = peg(local_tokens)
            local_tokens, region_tokens = transformer(local_tokens, region_tokens)

        return self.to_logits(region_tokens)
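A minimal usage sketch of `RegionViT` as defined in the new file above; the values simply echo the constructor defaults:

```python
import torch
from vit_pytorch.regionvit import RegionViT

model = RegionViT(
    dim = (64, 128, 256, 512),      # dimension at each of the four stages
    depth = (2, 2, 8, 2),           # depth of the region-to-local transformer at each stage
    window_size = 7,                # window size of the local attention
    num_classes = 1000,
    tokenize_local_3_conv = False,  # use a single strided conv (rather than three convs) to tokenize locally
    use_peg = False                 # positional encoding generator, off by default
)

img = torch.randn(1, 3, 224, 224)  # 224 is divisible by the region patch size (4 * 7 = 28)
pred = model(img) # (1, 1000)
```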
@@ -19,7 +19,7 @@ class AxialRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_freq = 10):
         super().__init__()
         self.dim = dim
-        scales = torch.logspace(1., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
+        scales = torch.linspace(1., max_freq / 2, self.dim // 4)
         self.register_buffer('scales', scales)

     def forward(self, x):
@@ -43,6 +43,16 @@ class AxialRotaryEmbedding(nn.Module):
         sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos))
         return sin, cos

+class DepthWiseConv2d(nn.Module):
+    def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
+            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
+        )
+    def forward(self, x):
+        return self.net(x)
+
 # helper classes

 class PreNorm(nn.Module):
@@ -53,17 +63,31 @@ class PreNorm(nn.Module):
     def forward(self, x, **kwargs):
         return self.fn(self.norm(x), **kwargs)

+class SpatialConv(nn.Module):
+    def __init__(self, dim_in, dim_out, kernel, bias = False):
+        super().__init__()
+        self.conv = DepthWiseConv2d(dim_in, dim_out, kernel, padding = kernel // 2, bias = False)
+        self.cls_proj = nn.Linear(dim_in, dim_out) if dim_in != dim_out else nn.Identity()
+
+    def forward(self, x, fmap_dims):
+        cls_token, x = x[:, :1], x[:, 1:]
+        x = rearrange(x, 'b (h w) d -> b d h w', **fmap_dims)
+        x = self.conv(x)
+        x = rearrange(x, 'b d h w -> b (h w) d')
+        cls_token = self.cls_proj(cls_token)
+        return torch.cat((cls_token, x), dim = 1)
+
 class GEGLU(nn.Module):
     def forward(self, x):
         x, gates = x.chunk(2, dim = -1)
         return F.gelu(gates) * x

 class FeedForward(nn.Module):
-    def __init__(self, dim, hidden_dim, dropout = 0.):
+    def __init__(self, dim, hidden_dim, dropout = 0., use_glu = True):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(dim, hidden_dim * 2),
-            GEGLU(),
+            nn.Linear(dim, hidden_dim * 2 if use_glu else hidden_dim),
+            GEGLU() if use_glu else nn.GELU(),
             nn.Dropout(dropout),
             nn.Linear(hidden_dim, dim),
             nn.Dropout(dropout)
@@ -72,36 +96,54 @@ class FeedForward(nn.Module):
         return self.net(x)

 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = True, use_ds_conv = True, conv_query_kernel = 5):
         super().__init__()
         inner_dim = dim_head * heads

+        self.use_rotary = use_rotary
         self.heads = heads
         self.scale = dim_head ** -0.5

         self.attend = nn.Softmax(dim = -1)
-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+
+        self.use_ds_conv = use_ds_conv
+
+        self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False) if use_ds_conv else nn.Linear(dim, inner_dim, bias = False)
+
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
         )

-    def forward(self, x, pos_emb):
+    def forward(self, x, pos_emb, fmap_dims):
         b, n, _, h = *x.shape, self.heads
-        qkv = self.to_qkv(x).chunk(3, dim = -1)
+
+        to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {}
+
+        q = self.to_q(x, **to_q_kwargs)
+
+        qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
+
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

-        # apply 2d rotary embeddings to queries and keys, excluding CLS tokens
+        if self.use_rotary:
+            # apply 2d rotary embeddings to queries and keys, excluding CLS tokens

-        sin, cos = pos_emb
-        (q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))
-        q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
+            sin, cos = pos_emb
+            dim_rotary = sin.shape[-1]

-        # concat back the CLS tokens
+            (q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))

-        q = torch.cat((q_cls, q), dim = 1)
-        k = torch.cat((k_cls, k), dim = 1)
+            # handle the case where rotary dimension < head dimension
+
+            (q, q_pass), (k, k_pass) = map(lambda t: (t[..., :dim_rotary], t[..., dim_rotary:]), (q, k))
+            q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
+            q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
+
+            # concat back the CLS tokens
+
+            q = torch.cat((q_cls, q), dim = 1)
+            k = torch.cat((k_cls, k), dim = 1)

         dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

@@ -112,39 +154,40 @@ class Attention(nn.Module):
         return self.to_out(out)

 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, image_size, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
         super().__init__()
         self.layers = nn.ModuleList([])
-        self.pos_emb = AxialRotaryEmbedding(dim_head)
+        self.pos_emb = AxialRotaryEmbedding(dim_head, max_freq = image_size)
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
+                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu))
             ]))
-    def forward(self, x):
+    def forward(self, x, fmap_dims):
         pos_emb = self.pos_emb(x[:, 1:])

         for attn, ff in self.layers:
-            x = attn(x, pos_emb = pos_emb) + x
+            x = attn(x, pos_emb = pos_emb, fmap_dims = fmap_dims) + x
             x = ff(x) + x
         return x

 # Rotary Vision Transformer

 class RvT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
         super().__init__()
         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
         num_patches = (image_size // patch_size) ** 2
         patch_dim = channels * patch_size ** 2

         self.patch_size = patch_size
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
             nn.Linear(patch_dim, dim),
         )

         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, image_size, dropout, use_rotary, use_ds_conv, use_glu)

         self.mlp_head = nn.Sequential(
             nn.LayerNorm(dim),
@@ -152,12 +195,15 @@ class RvT(nn.Module):
         )

     def forward(self, img):
+        b, _, h, w, p = *img.shape, self.patch_size
+
         x = self.to_patch_embedding(img)
-        b, n, _ = x.shape
+        n = x.shape[1]

         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim=1)

-        x = self.transformer(x)
+        fmap_dims = {'h': h // p, 'w': w // p}
+        x = self.transformer(x, fmap_dims = fmap_dims)

-        return self.mlp_head(x)
+        return self.mlp_head(x[:, 0])
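The constructor above now also exposes `use_rotary`, `use_ds_conv` and `use_glu`, all defaulting to `True`. A minimal sketch with illustrative hyperparameters (the `vit_pytorch.rvt` import path is assumed):

```python
import torch
from vit_pytorch.rvt import RvT

model = RvT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = model(img) # (1, 1000)
```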
@@ -35,13 +35,14 @@ class T2TViT(nn.Module):
         for i, (kernel_size, stride) in enumerate(t2t_layers):
             layer_dim *= kernel_size ** 2
             is_first = i == 0
+            is_last = i == (len(t2t_layers) - 1)
             output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)

             layers.extend([
                 RearrangeImage() if not is_first else nn.Identity(),
                 nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
                 Rearrange('b c n -> b n c'),
-                Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout),
+                Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout) if not is_last else nn.Identity(),
             ])

         layers.append(nn.Linear(layer_dim, dim))
@@ -71,7 +72,7 @@ class T2TViT(nn.Module):

         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding
+        x += self.pos_embedding[:, :n+1]
         x = self.dropout(x)

         x = self.transformer(x)
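A minimal T2T-ViT sketch; `t2t_layers` gives the (kernel size, stride) of each token-to-token stage, and with the change above the last stage no longer appends a Transformer. Hyperparameters are illustrative and the `vit_pytorch.t2t` import path is assumed:

```python
import torch
from vit_pytorch.t2t import T2TViT

v = T2TViT(
    dim = 512,
    image_size = 224,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    num_classes = 1000,
    t2t_layers = ((7, 4), (3, 2), (3, 2))  # (kernel size, stride) for each token-to-token stage
)

img = torch.randn(1, 3, 224, 224)
preds = v(img) # (1, 1000)
```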
vit_pytorch/twins_svt.py (new file, 229 lines)
@@ -0,0 +1,229 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper methods

def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

def group_by_key_prefix_and_remove_prefix(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs

# classes

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (std + self.eps) * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mult, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class PatchEmbedding(nn.Module):
    def __init__(self, *, dim, dim_out, patch_size):
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out
        self.patch_size = patch_size
        self.proj = nn.Conv2d(patch_size ** 2 * dim, dim_out, 1)

    def forward(self, fmap):
        p = self.patch_size
        fmap = rearrange(fmap, 'b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = p, p2 = p)
        return self.proj(fmap)

class PEG(nn.Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = Residual(nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1))

    def forward(self, x):
        return self.proj(x)

class LocalAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., patch_size = 7):
        super().__init__()
        inner_dim = dim_head * heads
        self.patch_size = patch_size
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, fmap):
        shape, p = fmap.shape, self.patch_size
        b, n, x, y, h = *shape, self.heads
        x, y = map(lambda t: t // p, (x, y))

        fmap = rearrange(fmap, 'b c (x p1) (y p2) -> (b x y) c p1 p2', p1 = p, p2 = p)

        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) p1 p2 -> (b h) (p1 p2) d', h = h), (q, k, v))

        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        attn = dots.softmax(dim = - 1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b x y h) (p1 p2) d -> b (h d) (x p1) (y p2)', h = h, x = x, y = y, p1 = p, p2 = p)
        return self.to_out(out)

class GlobalAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., k = 7):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        shape = x.shape
        b, n, _, y, h = *shape, self.heads
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))

        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))

        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        attn = dots.softmax(dim = -1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads = 8, dim_head = 64, mlp_mult = 4, local_patch_size = 7, global_k = 7, dropout = 0., has_local = True):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size))) if has_local else nn.Identity(),
                Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))) if has_local else nn.Identity(),
                Residual(PreNorm(dim, GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout)))
            ]))
    def forward(self, x):
        for local_attn, ff1, global_attn, ff2 in self.layers:
            x = local_attn(x)
            x = ff1(x)
            x = global_attn(x)
            x = ff2(x)
        return x

class TwinsSVT(nn.Module):
    def __init__(
        self,
        *,
        num_classes,
        s1_emb_dim = 64,
        s1_patch_size = 4,
        s1_local_patch_size = 7,
        s1_global_k = 7,
        s1_depth = 1,
        s2_emb_dim = 128,
        s2_patch_size = 2,
        s2_local_patch_size = 7,
        s2_global_k = 7,
        s2_depth = 1,
        s3_emb_dim = 256,
        s3_patch_size = 2,
        s3_local_patch_size = 7,
        s3_global_k = 7,
        s3_depth = 5,
        s4_emb_dim = 512,
        s4_patch_size = 2,
        s4_local_patch_size = 7,
        s4_global_k = 7,
        s4_depth = 4,
        peg_kernel_size = 3,
        dropout = 0.
    ):
        super().__init__()
        kwargs = dict(locals())

        dim = 3
        layers = []

        for prefix in ('s1', 's2', 's3', 's4'):
            config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
            is_last = prefix == 's4'

            dim_next = config['emb_dim']

            layers.append(nn.Sequential(
                PatchEmbedding(dim = dim, dim_out = dim_next, patch_size = config['patch_size']),
                Transformer(dim = dim_next, depth = 1, local_patch_size = config['local_patch_size'], global_k = config['global_k'], dropout = dropout, has_local = not is_last),
                PEG(dim = dim_next, kernel_size = peg_kernel_size),
                Transformer(dim = dim_next, depth = config['depth'], local_patch_size = config['local_patch_size'], global_k = config['global_k'], dropout = dropout, has_local = not is_last)
            ))

            dim = dim_next

        self.layers = nn.Sequential(
            *layers,
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...'),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        return self.layers(x)
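A minimal usage sketch of `TwinsSVT` from the new file above, leaving every stage hyperparameter at the defaults shown in the constructor:

```python
import torch
from vit_pytorch.twins_svt import TwinsSVT

model = TwinsSVT(num_classes = 1000)  # s1_* through s4_*, peg_kernel_size and dropout keep their defaults

img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```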
@@ -1,10 +1,16 @@
 import torch
-from torch import nn, einsum
-import torch.nn.functional as F
+from torch import nn

 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange

+# helpers
+
+def pair(t):
+    return t if isinstance(t, tuple) else (t, t)
+
+# classes
+
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
@@ -44,15 +50,14 @@ class Attention(nn.Module):
         ) if project_out else nn.Identity()

     def forward(self, x):
-        b, n, _, h = *x.shape, self.heads
         qkv = self.to_qkv(x).chunk(3, dim = -1)
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

-        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

         attn = self.attend(dots)

-        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = torch.matmul(attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)

@@ -74,13 +79,17 @@ class Transformer(nn.Module):
 class ViT(nn.Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
         super().__init__()
-        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
-        num_patches = (image_size // patch_size) ** 2
-        patch_dim = channels * patch_size ** 2
+        image_height, image_width = pair(image_size)
+        patch_height, patch_width = pair(patch_size)
+
+        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+
+        num_patches = (image_height // patch_height) * (image_width // patch_width)
+        patch_dim = channels * patch_height * patch_width
         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
             nn.Linear(patch_dim, dim),
         )
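Because `image_size` and `patch_size` now pass through `pair()`, either can be a tuple, so rectangular inputs work. A minimal sketch (values illustrative):

```python
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = (256, 128),  # (height, width)
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

img = torch.randn(1, 3, 256, 128)
preds = v(img) # (1, 1000)
```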