Compare commits

...

52 Commits

Author SHA1 Message Date
Phil Wang
d2d6de01d3 0.20.7 2021-08-30 08:14:43 -07:00
Phil Wang
b9eadaef60 Merge pull request #151 from developer0hye/patch-1
Cleanup Attention Class & matmul based implementation for TensorRT conversion
2021-08-30 08:14:11 -07:00
Yonghye Kwon
24ac8350bf remove unused package 2021-08-30 18:25:03 +09:00
Yonghye Kwon
ca3cef9de0 Cleanup Attention Class 2021-08-30 18:05:16 +09:00
Phil Wang
6e1be11517 0.20.6 2021-08-21 09:03:54 -07:00
Phil Wang
73ed562ce4 Merge pull request #147 from developer0hye/patch-4
Make T2T process any scale image
2021-08-21 09:03:42 -07:00
Phil Wang
ff863175a6 Merge pull request #146 from developer0hye/patch-1
Make Pit process image with width and height less than the image_size
2021-08-21 09:03:31 -07:00
Yonghye Kwon
ca0bdca192 Make model process any scale image
Related to #145
2021-08-21 22:35:26 +09:00
Yonghye Kwon
1c70271778 Support image with width and height less than the image_size
Related to #145
2021-08-21 22:25:46 +09:00
Phil Wang
d7d3febfe3 Merge pull request #144 from developer0hye/patch-1
Remove unused package
2021-08-20 10:14:02 -07:00
Yonghye Kwon
946815164a Remove unused package 2021-08-20 13:44:57 +09:00
Phil Wang
aeed3381c1 use hardswish for levit 2021-08-19 08:22:55 -07:00
Phil Wang
3f754956fb remove last transformer layer in t2t 2021-08-14 08:06:23 -07:00
Phil Wang
918869571c fix hard distillation, thanks to @CiaoHe 2021-08-12 08:40:57 -07:00
Phil Wang
e5324242be fix wrong norm in nest 2021-08-05 12:55:48 -07:00
Phil Wang
22da26fa4b fix recorder in data parallel situation 2021-07-08 10:15:07 -07:00
Phil Wang
a6c085a2df 0.20.0 for cct 2021-07-02 15:48:48 -07:00
Phil Wang
121353c604 Merge pull request #128 from stevenwalton/main
Adding Compact Convolutional Transformers (CCT)
2021-07-02 15:48:32 -07:00
alih
2ece3333da Minor changes 2021-07-01 17:51:35 -07:00
Ali Hassani
a73030c9aa Update README.md 2021-07-01 16:41:27 -07:00
Steven Walton
780f91a220 Tested and changed README format 2021-07-01 16:26:41 -07:00
Steven Walton
88451068e8 Adding CCT
Adding Compact Convolutional Transformers (CCT) from Escaping the Big Data
Paradigm with Compact Transformers by Hassani et. al.
https://arxiv.org/abs/2104.05704
2021-07-01 16:22:33 -07:00
Phil Wang
64a2ef6462 fix mpp 2021-06-16 16:46:32 -07:00
Phil Wang
53884f583f 0.19.5 2021-06-16 14:24:46 -07:00
Phil Wang
e616b5dcbc Merge pull request #101 from zankner/mpp-fix
Mpp fix
2021-06-16 14:24:26 -07:00
Phil Wang
60ad4e266e layernorm on channel dimension == instancenorm2d with affine set to true 2021-06-03 16:41:45 -07:00
Phil Wang
a254a0258a fix typo 2021-06-01 07:33:00 -07:00
Phil Wang
26df10c0b7 fix max pool in nest 2021-05-28 11:06:02 -07:00
Phil Wang
17cb8976df make nest resilient to dimension that are not divisible by number of heads 2021-05-27 22:41:07 -07:00
Phil Wang
daf3abbeb5 add NesT 2021-05-27 22:02:17 -07:00
Phil Wang
b483b16833 0.18.4 2021-05-18 14:40:33 -07:00
Phil Wang
c457573808 Merge pull request #118 from loctruong96/main
update  mpp.py to work on GPU
2021-05-18 14:40:17 -07:00
Loc Truong
e75b6d0251 Update mpp.py
fix issue with GPU device mismatch
2021-05-16 20:07:49 -07:00
Phil Wang
679e5be3e7 apply scale to 2d rel pos bias in levit 2021-05-10 11:37:23 -07:00
Phil Wang
7333979e6b add link to official repo for levit 2021-05-06 13:12:30 -07:00
Phil Wang
74b402377b add image 2021-05-02 15:40:53 -07:00
Phil Wang
41d2d460d0 link to yannic 2021-05-02 14:51:55 -07:00
Phil Wang
04f86dee3c implement SOTA new self-supervised learning technique from facebook for vision transformers, Dino 2021-05-02 14:00:36 -07:00
Phil Wang
6549522629 be able to accept non-square patches, thanks to @FilipAndersson245 2021-05-01 20:04:41 -07:00
Phil Wang
6a80a4ef89 update readme 2021-05-01 11:51:35 -07:00
Phil Wang
9f05587a7d 0.17.2 2021-04-30 06:44:59 -07:00
Phil Wang
65bb350e85 0.17.2 2021-04-30 06:44:54 -07:00
Phil Wang
fd4a7dfcf8 Merge pull request #102 from jon-tow/rvt-add-use-glu-flag
Add `use_glu` flag to `RvT`
2021-04-30 06:44:41 -07:00
Jonathan Tow
6f3a5fcf0b Add use_glu flag to RvT 2021-04-30 02:07:41 -04:00
Phil Wang
7807f24509 fix small bug 2021-04-29 15:39:41 -07:00
Phil Wang
a612327126 readme 2021-04-29 15:22:12 -07:00
Phil Wang
30a1335d31 release twins svt 2021-04-29 14:55:25 -07:00
Phil Wang
ab781f7ddb add Twins SVT (small) 2021-04-29 14:54:06 -07:00
Zack Ankner
a2df363224 adding un-normalizing targets and fix for mask token dimension 2021-04-29 15:43:22 -04:00
Phil Wang
4f3dbd003f for PiT, project to increased dimensions on first grouped conv for depthwise-conv 2021-04-29 12:41:00 -07:00
Zack Ankner
710b6d57d3 Merge pull request #1 from lucidrains/main
catch up
2021-04-29 19:33:25 +00:00
Phil Wang
60b5687a79 cleanup rvt 2021-04-27 11:45:46 -07:00
18 changed files with 1428 additions and 127 deletions

329
README.md
View File

@@ -38,6 +38,7 @@ preds = v(img) # (1, 1000)
```
## Parameters
- `image_size`: int.
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
- `patch_size`: int.
@@ -61,6 +62,7 @@ Dropout rate.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling
## Distillation
<img src="./images/distill.png" width="300px"></img>
@@ -117,6 +119,7 @@ v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```
## Deep ViT
This <a href="https://arxiv.org/abs/2103.11886">paper</a> notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a> paper from NLP.
@@ -200,6 +203,61 @@ img = torch.randn(1, 3, 224, 224)
preds = v(img) # (1, 1000)
```
## CCT
<img src="https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png" width="400px"></img>
<a href="https://arxiv.org/abs/2104.05704">CCT</a> proposes compact transformers
by using convolutions instead of patching and performing sequence pooling. This
allows for CCT to have high accuracy and a low number of parameters.
You can use this with two methods
```python
import torch
from vit_pytorch.cct import CCT
model = CCT(
img_size=224,
embedding_dim=384,
n_conv_layers=2,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
num_layers=14,
num_heads=6,
mlp_radio=3.,
num_classes=1000,
positional_embedding='learnable', # ['sine', 'learnable', 'none']
)
```
Alternatively you can use one of several pre-defined models `[2,4,6,7,8,14,16]`
which pre-define the number of layers, number of attention heads, the mlp ratio,
and the embedding dimension.
```python
import torch
from vit_pytorch.cct import cct_14
model = cct_14(
img_size=224,
n_conv_layers=1,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
num_classes=1000,
positional_embedding='learnable', # ['sine', 'learnable', 'none']
)
```
<a href="https://github.com/SHI-Labs/Compact-Transformers">Official
Repository</a> includes links to pretrained model checkpoints.
## Cross ViT
<img src="./images/cross_vit.png" width="400px"></img>
@@ -270,6 +328,8 @@ preds = v(img) # (1, 1000)
<a href="https://arxiv.org/abs/2104.01136">This paper</a> proposes a number of changes, including (1) convolutional embedding instead of patch-wise projection (2) downsampling in stages (3) extra non-linearity in attention (4) 2d relative positional biases instead of initial absolute positional bias (5) batchnorm in place of layernorm.
<a href="https://github.com/facebookresearch/LeViT">Official repository</a>
```python
import torch
from vit_pytorch.levit import LeViT
@@ -334,6 +394,73 @@ img = torch.randn(1, 3, 224, 224)
pred = v(img) # (1, 1000)
```
## Twins SVT
<img src="./images/twins_svt.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2104.13840">paper</a> proposes mixing local and global attention, along with position encoding generator (proposed in <a href="https://arxiv.org/abs/2102.10882">CPVT</a>) and global average pooling, to achieve the same results as <a href="https://arxiv.org/abs/2103.14030">Swin</a>, without the extra complexity of shifted windows, CLS tokens, nor positional embeddings.
```python
import torch
from vit_pytorch.twins_svt import TwinsSVT
model = TwinsSVT(
num_classes = 1000, # number of output classes
s1_emb_dim = 64, # stage 1 - patch embedding projected dimension
s1_patch_size = 4, # stage 1 - patch size for patch embedding
s1_local_patch_size = 7, # stage 1 - patch size for local attention
s1_global_k = 7, # stage 1 - global attention key / value reduction factor, defaults to 7 as specified in paper
s1_depth = 1, # stage 1 - number of transformer blocks (local attn -> ff -> global attn -> ff)
s2_emb_dim = 128, # stage 2 (same as above)
s2_patch_size = 2,
s2_local_patch_size = 7,
s2_global_k = 7,
s2_depth = 1,
s3_emb_dim = 256, # stage 3 (same as above)
s3_patch_size = 2,
s3_local_patch_size = 7,
s3_global_k = 7,
s3_depth = 5,
s4_emb_dim = 512, # stage 4 (same as above)
s4_patch_size = 2,
s4_local_patch_size = 7,
s4_global_k = 7,
s4_depth = 4,
peg_kernel_size = 3, # positional encoding generator kernel size
dropout = 0. # dropout
)
img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```
## NesT
<img src="./images/nest.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the heirarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.
You can use it with the following code (ex. NesT-T)
```python
import torch
from vit_pytorch.nest import NesT
nest = NesT(
image_size = 224,
patch_size = 4,
dim = 96,
heads = 3,
num_hierarchies = 3, # number of hierarchies
block_repeats = (8, 4, 1), # the number of transformer blocks at each heirarchy, starting from the bottom
num_classes = 1000
)
img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```
## Masked Patch Prediction
Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
@@ -367,7 +494,7 @@ mpp_trainer = MPP(
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
def sample_unlabelled_images():
return torch.randn(20, 3, 256, 256)
return torch.FloatTensor(20, 3, 256, 256).uniform_(0., 1.)
for _ in range(100):
images = sample_unlabelled_images()
@@ -380,6 +507,60 @@ for _ in range(100):
torch.save(model.state_dict(), './pretrained-net.pt')
```
## Dino
<img src="./images/dino.png" width="350px"></img>
You can train `ViT` with the recent SOTA self-supervised learning technique, <a href="https://arxiv.org/abs/2104.14294">Dino</a>, with the following code.
<a href="https://www.youtube.com/watch?v=h3ij3F3cPIk">Yannic Kilcher</a> video
```python
import torch
from vit_pytorch import ViT, Dino
model = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
)
learner = Dino(
model,
image_size = 256,
hidden_layer = 'to_latent', # hidden layer name or index, from which to extract the embedding
projection_hidden_size = 256, # projector network hidden dimension
projection_layers = 4, # number of layers in projection network
num_classes_K = 65336, # output logits dimensions (referenced as K in paper)
student_temp = 0.9, # student temperature
teacher_temp = 0.04, # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper
global_lower_crop_scale = 0.5, # lower bound for global crop - 0.5 was recommended in the paper
moving_average_decay = 0.9, # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)
opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)
def sample_unlabelled_images():
return torch.randn(20, 3, 256, 256)
for _ in range(100):
images = sample_unlabelled_images()
loss = learner(images)
opt.zero_grad()
loss.backward()
opt.step()
learner.update_moving_average() # update moving average of teacher encoder and teacher centers
# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')
```
## Accessing Attention
If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below
@@ -423,56 +604,6 @@ v = v.eject() # wrapper is discarded and original ViT instance is returned
## Research Ideas
### Self Supervised Training
You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
(1)
```bash
$ pip install byol-pytorch
```
(2)
```python
import torch
from vit_pytorch import ViT
from byol_pytorch import BYOL
model = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
)
learner = BYOL(
model,
image_size = 256,
hidden_layer = 'to_latent'
)
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
def sample_unlabelled_images():
return torch.randn(20, 3, 256, 256)
for _ in range(100):
images = sample_unlabelled_images()
loss = learner(images)
opt.zero_grad()
loss.backward()
opt.step()
learner.update_moving_average() # update moving average of target encoder
# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')
```
A pytorch-lightning script is ready for you to use at the repository link above.
### Efficient Attention
There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plugin your own sparse attention transformer.
@@ -542,6 +673,58 @@ img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)
```
## FAQ
- How do I pass in non-square images?
You can already pass in non-square images - you just have to make sure your height and width is less than or equal to the `image_size`, and both divisible by the `patch_size`
ex.
```python
import torch
from vit_pytorch import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 128) # <-- not a square
preds = v(img) # (1, 1000)
```
- How do I pass in non-square patches?
```python
import torch
from vit_pytorch import ViT
v = ViT(
num_classes = 1000,
image_size = (256, 128), # image size is a tuple of (height, width)
patch_size = (32, 16), # patch size is a tuple of (height, width)
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 128)
preds = v(img)
```
## Resources
Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
@@ -554,6 +737,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
## Citations
```bibtex
@article{hassani2021escaping,
title = {Escaping the Big Data Paradigm with Compact Transformers},
author = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
year = 2021,
url = {https://arxiv.org/abs/2104.05704},
eprint = {2104.05704},
archiveprefix = {arXiv},
primaryclass = {cs.CV}
}
```
```bibtex
@misc{dosovitskiy2020image,
@@ -665,6 +859,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{chu2021twins,
title = {Twins: Revisiting Spatial Attention Design in Vision Transformers},
author = {Xiangxiang Chu and Zhi Tian and Yuqing Wang and Bo Zhang and Haibing Ren and Xiaolin Wei and Huaxia Xia and Chunhua Shen},
year = {2021},
eprint = {2104.13840},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
@@ -676,6 +881,28 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{zhang2021aggregating,
title = {Aggregating Nested Transformers},
author = {Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
year = {2021},
eprint = {2105.12723},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{caron2021emerging,
title = {Emerging Properties in Self-Supervised Vision Transformers},
author = {Mathilde Caron and Hugo Touvron and Ishan Misra and Hervé Jégou and Julien Mairal and Piotr Bojanowski and Armand Joulin},
year = {2021},
eprint = {2104.14294},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},

BIN
images/dino.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

BIN
images/nest.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB

BIN
images/twins_svt.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.16.12',
version = '0.20.7',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
@@ -15,8 +15,9 @@ setup(
'image recognition'
],
install_requires=[
'einops>=0.3',
'torch>=1.6',
'einops>=0.3'
'torchvision'
],
classifiers=[
'Development Status :: 4 - Beta',

View File

@@ -1 +1,2 @@
from vit_pytorch.vit import ViT
from vit_pytorch.dino import Dino

339
vit_pytorch/cct.py Normal file
View File

@@ -0,0 +1,339 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# Pre-defined CCT Models
__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']
def cct_2(*args, **kwargs):
return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
*args, **kwargs)
def cct_4(*args, **kwargs):
return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
*args, **kwargs)
def cct_6(*args, **kwargs):
return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs)
def cct_7(*args, **kwargs):
return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs)
def cct_8(*args, **kwargs):
return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs)
def cct_14(*args, **kwargs):
return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
*args, **kwargs)
def cct_16(*args, **kwargs):
return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384,
*args, **kwargs)
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
kernel_size=3, stride=None, padding=None,
*args, **kwargs):
stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
padding = padding if padding is not None else max(1, (kernel_size // 2))
return CCT(num_layers=num_layers,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
embedding_dim=embedding_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
*args, **kwargs)
# Modules
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
super().__init__()
self.num_heads = num_heads
head_dim = dim // self.num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.attn_drop = nn.Dropout(attention_dropout)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(projection_dropout)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class TransformerEncoderLayer(nn.Module):
"""
Inspired by torch.nn.TransformerEncoderLayer and
rwightman's timm package.
"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
attention_dropout=0.1, drop_path_rate=0.1):
super(TransformerEncoderLayer, self).__init__()
self.pre_norm = nn.LayerNorm(d_model)
self.self_attn = Attention(dim=d_model, num_heads=nhead,
attention_dropout=attention_dropout, projection_dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout2 = nn.Dropout(dropout)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.activation = F.gelu
def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
src = self.norm1(src)
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
src = src + self.drop_path(self.dropout2(src2))
return src
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""
Obtained from: github.com:rwightman/pytorch-image-models
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""
Obtained from: github.com:rwightman/pytorch-image-models
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class Tokenizer(nn.Module):
def __init__(self,
kernel_size, stride, padding,
pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
n_conv_layers=1,
n_input_channels=3,
n_output_channels=64,
in_planes=64,
activation=None,
max_pool=True,
conv_bias=False):
super(Tokenizer, self).__init__()
n_filter_list = [n_input_channels] + \
[in_planes for _ in range(n_conv_layers - 1)] + \
[n_output_channels]
self.conv_layers = nn.Sequential(
*[nn.Sequential(
nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
kernel_size=(kernel_size, kernel_size),
stride=(stride, stride),
padding=(padding, padding), bias=conv_bias),
nn.Identity() if activation is None else activation(),
nn.MaxPool2d(kernel_size=pooling_kernel_size,
stride=pooling_stride,
padding=pooling_padding) if max_pool else nn.Identity()
)
for i in range(n_conv_layers)
])
self.flattener = nn.Flatten(2, 3)
self.apply(self.init_weight)
def sequence_length(self, n_channels=3, height=224, width=224):
return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
def forward(self, x):
return self.flattener(self.conv_layers(x)).transpose(-2, -1)
@staticmethod
def init_weight(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
class TransformerClassifier(nn.Module):
def __init__(self,
seq_pool=True,
embedding_dim=768,
num_layers=12,
num_heads=12,
mlp_ratio=4.0,
num_classes=1000,
dropout_rate=0.1,
attention_dropout=0.1,
stochastic_depth_rate=0.1,
positional_embedding='sine',
sequence_length=None,
*args, **kwargs):
super().__init__()
positional_embedding = positional_embedding if \
positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
dim_feedforward = int(embedding_dim * mlp_ratio)
self.embedding_dim = embedding_dim
self.sequence_length = sequence_length
self.seq_pool = seq_pool
assert sequence_length is not None or positional_embedding == 'none', \
f"Positional embedding is set to {positional_embedding} and" \
f" the sequence length was not specified."
if not seq_pool:
sequence_length += 1
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim),
requires_grad=True)
else:
self.attention_pool = nn.Linear(self.embedding_dim, 1)
if positional_embedding != 'none':
if positional_embedding == 'learnable':
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
requires_grad=True)
nn.init.trunc_normal_(self.positional_emb, std=0.2)
else:
self.positional_emb = nn.Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
requires_grad=False)
else:
self.positional_emb = None
self.dropout = nn.Dropout(p=dropout_rate)
dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
self.blocks = nn.ModuleList([
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
dim_feedforward=dim_feedforward, dropout=dropout_rate,
attention_dropout=attention_dropout, drop_path_rate=dpr[i])
for i in range(num_layers)])
self.norm = nn.LayerNorm(embedding_dim)
self.fc = nn.Linear(embedding_dim, num_classes)
self.apply(self.init_weight)
def forward(self, x):
if self.positional_emb is None and x.size(1) < self.sequence_length:
x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)
if not self.seq_pool:
cls_token = self.class_emb.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
if self.positional_emb is not None:
x += self.positional_emb
x = self.dropout(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
if self.seq_pool:
x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
else:
x = x[:, 0]
x = self.fc(x)
return x
@staticmethod
def init_weight(m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@staticmethod
def sinusoidal_embedding(n_channels, dim):
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
for p in range(n_channels)])
pe[:, 0::2] = torch.sin(pe[:, 0::2])
pe[:, 1::2] = torch.cos(pe[:, 1::2])
return pe.unsqueeze(0)
# CCT Main model
class CCT(nn.Module):
def __init__(self,
img_size=224,
embedding_dim=768,
n_input_channels=3,
n_conv_layers=1,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
*args, **kwargs):
super(CCT, self).__init__()
self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
n_output_channels=embedding_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
pooling_kernel_size=pooling_kernel_size,
pooling_stride=pooling_stride,
pooling_padding=pooling_padding,
max_pool=True,
activation=nn.ReLU,
n_conv_layers=n_conv_layers,
conv_bias=False)
self.classifier = TransformerClassifier(
sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
height=img_size,
width=img_size),
embedding_dim=embedding_dim,
seq_pool=True,
dropout_rate=0.,
attention_dropout=0.1,
stochastic_depth=0.1,
*args, **kwargs)
def forward(self, x):
x = self.tokenizer(x)
return self.classifier(x)

303
vit_pytorch/dino.py Normal file
View File

@@ -0,0 +1,303 @@
import copy
import random
from functools import wraps, partial
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms as T
# helper functions
def exists(val):
return val is not None
def default(val, default):
return val if exists(val) else default
def singleton(cache_key):
def inner_fn(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
instance = getattr(self, cache_key)
if instance is not None:
return instance
instance = fn(self, *args, **kwargs)
setattr(self, cache_key, instance)
return instance
return wrapper
return inner_fn
def get_module_device(module):
return next(module.parameters()).device
def set_requires_grad(model, val):
for p in model.parameters():
p.requires_grad = val
# loss function # (algorithm 1 in the paper)
def loss_fn(
teacher_logits,
student_logits,
teacher_temp,
student_temp,
centers,
eps = 1e-20
):
teacher_logits = teacher_logits.detach()
student_probs = (student_logits / student_temp).softmax(dim = -1)
teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
return - (teacher_probs * torch.log(student_probs + eps)).sum(dim = -1).mean()
# augmentation utils
class RandomApply(nn.Module):
def __init__(self, fn, p):
super().__init__()
self.fn = fn
self.p = p
def forward(self, x):
if random.random() > self.p:
return x
return self.fn(x)
# exponential moving average
class EMA():
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new
def update_moving_average(ema_updater, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = ema_updater.update_average(old_weight, up_weight)
# MLP class for projector and predictor
class L2Norm(nn.Module):
def forward(self, x, eps = 1e-6):
norm = x.norm(dim = 1, keepdim = True).clamp(min = eps)
return x / norm
class MLP(nn.Module):
def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
super().__init__()
layers = []
dims = (dim, *((hidden_size,) * (num_layers - 1)))
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
is_last = ind == (len(dims) - 1)
layers.extend([
nn.Linear(layer_dim_in, layer_dim_out),
nn.GELU() if not is_last else nn.Identity()
])
self.net = nn.Sequential(
*layers,
L2Norm(),
nn.Linear(hidden_size, dim_out)
)
def forward(self, x):
return self.net(x)
# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projecter and predictor nets
class NetWrapper(nn.Module):
def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
super().__init__()
self.net = net
self.layer = layer
self.projector = None
self.projection_hidden_size = projection_hidden_size
self.projection_num_layers = projection_num_layers
self.output_dim = output_dim
self.hidden = {}
self.hook_registered = False
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
def _hook(self, _, input, output):
device = input[0].device
self.hidden[device] = output.flatten(1)
def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
@singleton('projector')
def _get_projector(self, hidden):
_, dim = hidden.shape
projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
return projector.to(hidden)
def get_embedding(self, x):
if self.layer == -1:
return self.net(x)
if not self.hook_registered:
self._register_hook()
self.hidden.clear()
_ = self.net(x)
hidden = self.hidden[x.device]
self.hidden.clear()
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
def forward(self, x, return_projection = True):
embed = self.get_embedding(x)
if not return_projection:
return embed
projector = self._get_projector(embed)
return projector(embed), embed
# main class
class Dino(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer = -2,
projection_hidden_size = 256,
num_classes_K = 65336,
projection_layers = 4,
student_temp = 0.9,
teacher_temp = 0.04,
local_upper_crop_scale = 0.4,
global_lower_crop_scale = 0.5,
moving_average_decay = 0.9,
center_moving_average_decay = 0.9,
augment_fn = None,
augment_fn2 = None
):
super().__init__()
self.net = net
# default BYOL augmentation
DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, DEFAULT_AUG)
# local and global crops
self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))
self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)
self.teacher_encoder = None
self.teacher_ema_updater = EMA(moving_average_decay)
self.register_buffer('teacher_centers', torch.zeros(1, num_classes_K))
self.register_buffer('last_teacher_centers', torch.zeros(1, num_classes_K))
self.teacher_centering_ema_updater = EMA(center_moving_average_decay)
self.student_temp = student_temp
self.teacher_temp = teacher_temp
# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)
# send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
@singleton('teacher_encoder')
def _get_teacher_encoder(self):
teacher_encoder = copy.deepcopy(self.student_encoder)
set_requires_grad(teacher_encoder, False)
return teacher_encoder
def reset_moving_average(self):
del self.teacher_encoder
self.teacher_encoder = None
def update_moving_average(self):
assert self.teacher_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)
new_teacher_centers = self.teacher_centering_ema_updater.update_average(self.teacher_centers, self.last_teacher_centers)
self.teacher_centers.copy_(new_teacher_centers)
def forward(
self,
x,
return_embedding = False,
return_projection = True,
student_temp = None,
teacher_temp = None
):
if return_embedding:
return self.student_encoder(x, return_projection = return_projection)
image_one, image_two = self.augment1(x), self.augment2(x)
local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
student_proj_one, _ = self.student_encoder(local_image_one)
student_proj_two, _ = self.student_encoder(local_image_two)
with torch.no_grad():
teacher_encoder = self._get_teacher_encoder()
teacher_proj_one, _ = teacher_encoder(global_image_one)
teacher_proj_two, _ = teacher_encoder(global_image_two)
loss_fn_ = partial(
loss_fn,
student_temp = default(student_temp, self.student_temp),
teacher_temp = default(teacher_temp, self.teacher_temp),
centers = self.teacher_centers
)
teacher_logits_avg = torch.cat((teacher_proj_one, teacher_proj_two)).mean(dim = 0)
self.last_teacher_centers.copy_(teacher_logits_avg)
loss = (loss_fn_(teacher_proj_one, student_proj_two) + loss_fn_(teacher_proj_two, student_proj_one)) / 2
return loss

View File

@@ -148,6 +148,6 @@ class DistillWrapper(nn.Module):
else:
teacher_labels = teacher_logits.argmax(dim = -1)
distill_loss = F.cross_entropy(student_logits, teacher_labels)
distill_loss = F.cross_entropy(distill_logits, teacher_labels)
return loss * (1 - alpha) + distill_loss * alpha

View File

@@ -29,7 +29,7 @@ class FeedForward(nn.Module):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Hardswish(),
nn.Dropout(dropout),
nn.Conv2d(dim * mult, dim, 1),
nn.Dropout(dropout)
@@ -84,7 +84,7 @@ class Attention(nn.Module):
def apply_pos_bias(self, fmap):
bias = self.pos_bias(self.pos_indices)
bias = rearrange(bias, 'i j h -> () h i j')
return fmap + bias
return fmap + (bias / self.scale)
def forward(self, x):
b, n, *_, h = *x.shape, self.heads

View File

@@ -1,20 +1,20 @@
import math
from functools import reduce
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops import rearrange, repeat, reduce
# helpers
def exists(val):
return val is not None
def prob_mask_like(t, prob):
batch, seq_length, _ = t.shape
return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob
def get_mask_subset_with_prob(patched_input, prob):
batch, seq_len, _, device = *patched_input.shape, patched_input.device
max_masked = math.ceil(prob * seq_len)
@@ -31,43 +31,45 @@ def get_mask_subset_with_prob(patched_input, prob):
class MPPLoss(nn.Module):
def __init__(self, patch_size, channels, output_channel_bits,
max_pixel_val):
super(MPPLoss, self).__init__()
def __init__(
self,
patch_size,
channels,
output_channel_bits,
max_pixel_val,
mean,
std
):
super().__init__()
self.patch_size = patch_size
self.channels = channels
self.output_channel_bits = output_channel_bits
self.max_pixel_val = max_pixel_val
self.mean = torch.tensor(mean).view(-1, 1, 1) if mean else None
self.std = torch.tensor(std).view(-1, 1, 1) if std else None
def forward(self, predicted_patches, target, mask):
p, c, mpv, bits, device = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits, target.device
bin_size = mpv / (2 ** bits)
# un-normalize input
if exists(self.mean) and exists(self.std):
target = target * self.std + self.mean
# reshape target to patches
p = self.patch_size
target = rearrange(target,
"b c (h p1) (w p2) -> b (h w) c (p1 p2) ",
p1=p,
p2=p)
target = target.clamp(max = mpv) # clamp just in case
avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1 = p, p2 = p).contiguous()
avg_target = target.mean(dim=3)
bin_size = self.max_pixel_val / self.output_channel_bits
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size)
channel_bins = torch.arange(bin_size, mpv, bin_size, device = device)
discretized_target = torch.bucketize(avg_target, channel_bins)
discretized_target = F.one_hot(discretized_target,
self.output_channel_bits)
c, bi = self.channels, self.output_channel_bits
discretized_target = rearrange(discretized_target,
"b n c bi -> b n (c bi)",
c=c,
bi=bi)
bin_mask = 2**torch.arange(c * bi - 1, -1,
-1).to(discretized_target.device,
discretized_target.dtype)
target_label = torch.sum(bin_mask * discretized_target, -1)
bin_mask = (2 ** bits) ** torch.arange(0, c, device = device).long()
bin_mask = rearrange(bin_mask, 'c -> () () c')
predicted_patches = predicted_patches[mask]
target_label = target_label[mask]
loss = F.cross_entropy(predicted_patches, target_label)
target_label = torch.sum(bin_mask * discretized_target, dim = -1)
loss = F.cross_entropy(predicted_patches[mask], target_label[mask])
return loss
@@ -75,21 +77,24 @@ class MPPLoss(nn.Module):
class MPP(nn.Module):
def __init__(self,
transformer,
patch_size,
dim,
output_channel_bits=3,
channels=3,
max_pixel_val=1.0,
mask_prob=0.15,
replace_prob=0.5,
random_patch_prob=0.5):
def __init__(
self,
transformer,
patch_size,
dim,
output_channel_bits=3,
channels=3,
max_pixel_val=1.0,
mask_prob=0.15,
replace_prob=0.5,
random_patch_prob=0.5,
mean=None,
std=None
):
super().__init__()
self.transformer = transformer
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
max_pixel_val)
max_pixel_val, mean, std)
# output transformation
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
@@ -103,7 +108,7 @@ class MPP(nn.Module):
self.random_patch_prob = random_patch_prob
# token ids
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))
def forward(self, input, **kwargs):
transformer = self.transformer
@@ -127,8 +132,9 @@ class MPP(nn.Module):
random_patch_sampling_prob = self.random_patch_prob / (
1 - self.replace_prob)
random_patch_prob = prob_mask_like(input,
random_patch_sampling_prob)
bool_random_patch_prob = mask * random_patch_prob == True
random_patch_sampling_prob).to(mask.device)
bool_random_patch_prob = mask * (random_patch_prob == True)
random_patches = torch.randint(0,
input.shape[1],
(input.shape[0], input.shape[1]),
@@ -140,7 +146,7 @@ class MPP(nn.Module):
bool_random_patch_prob]
# [mask] input
replace_prob = prob_mask_like(input, self.replace_prob)
replace_prob = prob_mask_like(input, self.replace_prob).to(mask.device)
bool_mask_replace = (mask * replace_prob) == True
masked_input[bool_mask_replace] = self.mask_token

179
vit_pytorch/nest.py Normal file
View File

@@ -0,0 +1,179 @@
from functools import partial
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
# helpers
def cast_tuple(val, depth):
return val if isinstance(val, tuple) else ((val,) * depth)
# classes
class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mlp_mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mlp_mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(dim * mlp_mult, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dropout = 0.):
super().__init__()
dim_head = dim // heads
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
b, c, h, w, heads = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
return self.to_out(out)
def Aggregate(dim, dim_out):
return nn.Sequential(
nn.Conv2d(dim, dim_out, 3, padding = 1),
LayerNorm(dim_out),
nn.MaxPool2d(3, stride = 2, padding = 1)
)
class Transformer(nn.Module):
def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
self.pos_emb = nn.Parameter(torch.randn(seq_len))
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
]))
def forward(self, x):
*_, h, w = x.shape
pos_emb = self.pos_emb[:(h * w)]
pos_emb = rearrange(pos_emb, '(h w) -> () () h w', h = h, w = w)
x = x + pos_emb
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class NesT(nn.Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
heads,
num_hierarchies,
block_repeats,
mlp_mult = 4,
channels = 3,
dim_head = 64,
dropout = 0.
):
super().__init__()
assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
fmap_size = image_size // patch_size
blocks = 2 ** (num_hierarchies - 1)
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy
hierarchies = list(reversed(range(num_hierarchies)))
mults = [2 ** i for i in hierarchies]
layer_heads = list(map(lambda t: t * heads, mults))
layer_dims = list(map(lambda t: t * dim, mults))
layer_dims = [*layer_dims, layer_dims[-1]]
dim_pairs = zip(layer_dims[:-1], layer_dims[1:])
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
nn.Conv2d(patch_dim, layer_dims[0], 1),
)
block_repeats = cast_tuple(block_repeats, num_hierarchies)
self.layers = nn.ModuleList([])
for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats):
is_last = level == 0
depth = block_repeat
self.layers.append(nn.ModuleList([
Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout),
Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
]))
self.mlp_head = nn.Sequential(
LayerNorm(dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, c, h, w = x.shape
num_hierarchies = len(self.layers)
for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
block_size = 2 ** level
x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
x = transformer(x)
x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
x = aggregate(x)
return self.mlp_head(x)

View File

@@ -89,8 +89,8 @@ class DepthWiseConv2d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
nn.Conv2d(dim_out, dim_out, kernel_size = 1, bias = bias)
)
def forward(self, x):
return self.net(x)
@@ -175,7 +175,7 @@ class PiT(nn.Module):
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :n+1]
x = self.dropout(x)
x = self.layers(x)

View File

@@ -8,7 +8,7 @@ def find_modules(nn_module, type):
return [module for module in nn_module.modules() if isinstance(module, type)]
class Recorder(nn.Module):
def __init__(self, vit):
def __init__(self, vit, device = None):
super().__init__()
self.vit = vit
@@ -17,6 +17,7 @@ class Recorder(nn.Module):
self.hooks = []
self.hook_registered = False
self.ejected = False
self.device = device
def _hook(self, _, input, output):
self.recordings.append(output.clone().detach())
@@ -45,10 +46,14 @@ class Recorder(nn.Module):
def forward(self, img):
assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
self.clear()
if not self.hook_registered:
self._register_hook()
pred = self.vit(img)
attns = torch.stack(self.recordings, dim = 1)
# move all recordings to one device before stacking
target_device = self.device if self.device is not None else img.device
recordings = tuple(map(lambda t: t.to(target_device), self.recordings))
attns = torch.stack(recordings, dim = 1)
return pred, attns

View File

@@ -83,11 +83,11 @@ class GEGLU(nn.Module):
return F.gelu(gates) * x
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
def __init__(self, dim, hidden_dim, dropout = 0., use_glu = True):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim * 2),
GEGLU(),
nn.Linear(dim, hidden_dim * 2 if use_glu else hidden_dim),
GEGLU() if use_glu else nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
@@ -119,7 +119,8 @@ class Attention(nn.Module):
def forward(self, x, pos_emb, fmap_dims):
b, n, _, h = *x.shape, self.heads
q = self.to_q(x, fmap_dims = fmap_dims) if self.use_ds_conv else self.to_q(x)
to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {}
q = self.to_q(x, **to_q_kwargs)
qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
@@ -153,14 +154,14 @@ class Attention(nn.Module):
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True, use_ds_conv = True):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
super().__init__()
self.layers = nn.ModuleList([])
self.pos_emb = AxialRotaryEmbedding(dim_head)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu))
]))
def forward(self, x, fmap_dims):
pos_emb = self.pos_emb(x[:, 1:])
@@ -173,7 +174,7 @@ class Transformer(nn.Module):
# Rotary Vision Transformer
class RvT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True, use_ds_conv = True):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
@@ -186,7 +187,7 @@ class RvT(nn.Module):
)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary, use_ds_conv)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary, use_ds_conv, use_glu)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),

View File

@@ -35,13 +35,14 @@ class T2TViT(nn.Module):
for i, (kernel_size, stride) in enumerate(t2t_layers):
layer_dim *= kernel_size ** 2
is_first = i == 0
is_last = i == (len(t2t_layers) - 1)
output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)
layers.extend([
RearrangeImage() if not is_first else nn.Identity(),
nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
Rearrange('b c n -> b n c'),
Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout),
Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout) if not is_last else nn.Identity(),
])
layers.append(nn.Linear(layer_dim, dim))
@@ -71,7 +72,7 @@ class T2TViT(nn.Module):
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :n+1]
x = self.dropout(x)
x = self.transformer(x)

229
vit_pytorch/twins_svt.py Normal file
View File

@@ -0,0 +1,229 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helper methods
def group_dict_by_key(cond, d):
return_val = [dict(), dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def group_by_key_prefix_and_remove_prefix(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# classes
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(dim * mult, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class PatchEmbedding(nn.Module):
def __init__(self, *, dim, dim_out, patch_size):
super().__init__()
self.dim = dim
self.dim_out = dim_out
self.patch_size = patch_size
self.proj = nn.Conv2d(patch_size ** 2 * dim, dim_out, 1)
def forward(self, fmap):
p = self.patch_size
fmap = rearrange(fmap, 'b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = p, p2 = p)
return self.proj(fmap)
class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = Residual(nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1))
def forward(self, x):
return self.proj(x)
class LocalAttention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., patch_size = 7):
super().__init__()
inner_dim = dim_head * heads
self.patch_size = patch_size
self.heads = heads
self.scale = dim_head ** -0.5
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, fmap):
shape, p = fmap.shape, self.patch_size
b, n, x, y, h = *shape, self.heads
x, y = map(lambda t: t // p, (x, y))
fmap = rearrange(fmap, 'b c (x p1) (y p2) -> (b x y) c p1 p2', p1 = p, p2 = p)
q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
q, k, v = map(lambda t: rearrange(t, 'b (h d) p1 p2 -> (b h) (p1 p2) d', h = h), (q, k, v))
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = dots.softmax(dim = - 1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b x y h) (p1 p2) d -> b (h d) (x p1) (y p2)', h = h, x = x, y = y, p1 = p, p2 = p)
return self.to_out(out)
class GlobalAttention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., k = 7):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
shape = x.shape
b, n, _, y, h = *shape, self.heads
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = dots.softmax(dim = -1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads = 8, dim_head = 64, mlp_mult = 4, local_patch_size = 7, global_k = 7, dropout = 0., has_local = True):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size))) if has_local else nn.Identity(),
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))) if has_local else nn.Identity(),
Residual(PreNorm(dim, GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k))),
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout)))
]))
def forward(self, x):
for local_attn, ff1, global_attn, ff2 in self.layers:
x = local_attn(x)
x = ff1(x)
x = global_attn(x)
x = ff2(x)
return x
class TwinsSVT(nn.Module):
def __init__(
self,
*,
num_classes,
s1_emb_dim = 64,
s1_patch_size = 4,
s1_local_patch_size = 7,
s1_global_k = 7,
s1_depth = 1,
s2_emb_dim = 128,
s2_patch_size = 2,
s2_local_patch_size = 7,
s2_global_k = 7,
s2_depth = 1,
s3_emb_dim = 256,
s3_patch_size = 2,
s3_local_patch_size = 7,
s3_global_k = 7,
s3_depth = 5,
s4_emb_dim = 512,
s4_patch_size = 2,
s4_local_patch_size = 7,
s4_global_k = 7,
s4_depth = 4,
peg_kernel_size = 3,
dropout = 0.
):
super().__init__()
kwargs = dict(locals())
dim = 3
layers = []
for prefix in ('s1', 's2', 's3', 's4'):
config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
is_last = prefix == 's4'
dim_next = config['emb_dim']
layers.append(nn.Sequential(
PatchEmbedding(dim = dim, dim_out = dim_next, patch_size = config['patch_size']),
Transformer(dim = dim_next, depth = 1, local_patch_size = config['local_patch_size'], global_k = config['global_k'], dropout = dropout, has_local = not is_last),
PEG(dim = dim_next, kernel_size = peg_kernel_size),
Transformer(dim = dim_next, depth = config['depth'], local_patch_size = config['local_patch_size'], global_k = config['global_k'], dropout = dropout, has_local = not is_last)
))
dim = dim_next
self.layers = nn.Sequential(
*layers,
nn.AdaptiveAvgPool2d(1),
Rearrange('... () () -> ...'),
nn.Linear(dim, num_classes)
)
def forward(self, x):
return self.layers(x)

View File

@@ -1,10 +1,16 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
@@ -44,15 +50,14 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -74,13 +79,17 @@ class Transformer(nn.Module):
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)