mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
39 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e8f6d72033 | ||
|
|
cb1729af28 | ||
|
|
9e50b2a41e | ||
|
|
06d375351e | ||
|
|
f196d1ec5b | ||
|
|
529044c9b3 | ||
|
|
c30655f3bc | ||
|
|
d2d6de01d3 | ||
|
|
b9eadaef60 | ||
|
|
24ac8350bf | ||
|
|
ca3cef9de0 | ||
|
|
6e1be11517 | ||
|
|
73ed562ce4 | ||
|
|
ff863175a6 | ||
|
|
ca0bdca192 | ||
|
|
1c70271778 | ||
|
|
d7d3febfe3 | ||
|
|
946815164a | ||
|
|
aeed3381c1 | ||
|
|
3f754956fb | ||
|
|
918869571c | ||
|
|
e5324242be | ||
|
|
22da26fa4b | ||
|
|
a6c085a2df | ||
|
|
121353c604 | ||
|
|
2ece3333da | ||
|
|
a73030c9aa | ||
|
|
780f91a220 | ||
|
|
88451068e8 | ||
|
|
64a2ef6462 | ||
|
|
53884f583f | ||
|
|
e616b5dcbc | ||
|
|
60ad4e266e | ||
|
|
a254a0258a | ||
|
|
26df10c0b7 | ||
|
|
17cb8976df | ||
|
|
daf3abbeb5 | ||
|
|
a2df363224 | ||
|
|
710b6d57d3 |
197
README.md
197
README.md
@@ -62,6 +62,7 @@ Dropout rate.
|
||||
Embedding dropout rate.
|
||||
- `pool`: string, either `cls` token pooling or `mean` pooling
|
||||
|
||||
|
||||
## Distillation
|
||||
|
||||
<img src="./images/distill.png" width="300px"></img>
|
||||
@@ -118,6 +119,7 @@ v = v.to_vit()
|
||||
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
|
||||
```
|
||||
|
||||
|
||||
## Deep ViT
|
||||
|
||||
This <a href="https://arxiv.org/abs/2103.11886">paper</a> notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a> paper from NLP.
|
||||
@@ -201,6 +203,61 @@ img = torch.randn(1, 3, 224, 224)
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## CCT
|
||||
<img src="https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2104.05704">CCT</a> proposes compact transformers
|
||||
by using convolutions instead of patching and performing sequence pooling. This
|
||||
allows for CCT to have high accuracy and a low number of parameters.
|
||||
|
||||
You can use this with two methods
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.cct import CCT
|
||||
|
||||
model = CCT(
|
||||
img_size=224,
|
||||
embedding_dim=384,
|
||||
n_conv_layers=2,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
num_layers=14,
|
||||
num_heads=6,
|
||||
mlp_radio=3.,
|
||||
num_classes=1000,
|
||||
positional_embedding='learnable', # ['sine', 'learnable', 'none']
|
||||
)
|
||||
```
|
||||
|
||||
Alternatively you can use one of several pre-defined models `[2,4,6,7,8,14,16]`
|
||||
which pre-define the number of layers, number of attention heads, the mlp ratio,
|
||||
and the embedding dimension.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.cct import cct_14
|
||||
|
||||
model = cct_14(
|
||||
img_size=224,
|
||||
n_conv_layers=1,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
num_classes=1000,
|
||||
positional_embedding='learnable', # ['sine', 'learnable', 'none']
|
||||
)
|
||||
```
|
||||
<a href="https://github.com/SHI-Labs/Compact-Transformers">Official
|
||||
Repository</a> includes links to pretrained model checkpoints.
|
||||
|
||||
|
||||
## Cross ViT
|
||||
|
||||
<img src="./images/cross_vit.png" width="400px"></img>
|
||||
@@ -378,6 +435,100 @@ img = torch.randn(1, 3, 224, 224)
|
||||
pred = model(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## RegionViT
|
||||
|
||||
<img src="./images/regionvit.png" width="400px"></img>
|
||||
|
||||
<img src="./images/regionvit2.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2106.02689">This paper</a> proposes to divide up the feature map into local regions, whereby the local tokens attend to each other. Each local region has its own regional token which then attends to all its local tokens, as well as other regional tokens.
|
||||
|
||||
You can use it as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.regionvit import RegionViT
|
||||
|
||||
model = RegionViT(
|
||||
dim = (64, 128, 256, 512), # tuple of size 4, indicating dimension at each stage
|
||||
depth = (2, 2, 8, 2), # depth of the region to local transformer at each stage
|
||||
window_size = 7, # window size, which should be either 7 or 14
|
||||
num_classes = 1000, # number of output lcasses
|
||||
tokenize_local_3_conv = False, # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
|
||||
use_peg = False, # whether to use positional generating module. they used this for object detection for a boost in performance
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
|
||||
pred = model(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## NesT
|
||||
|
||||
<img src="./images/nest.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the heirarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.
|
||||
|
||||
You can use it with the following code (ex. NesT-T)
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.nest import NesT
|
||||
|
||||
nest = NesT(
|
||||
image_size = 224,
|
||||
patch_size = 4,
|
||||
dim = 96,
|
||||
heads = 3,
|
||||
num_hierarchies = 3, # number of hierarchies
|
||||
block_repeats = (8, 4, 1), # the number of transformer blocks at each heirarchy, starting from the bottom
|
||||
num_classes = 1000
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
|
||||
pred = nest(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Masked Autoencoder
|
||||
|
||||
<img src="./images/mae.png" width="400px"/>
|
||||
|
||||
A new <a href="https://arxiv.org/abs/2111.06377">Kaiming He paper</a> proposes a simple autoencoder scheme where the vision transformer attends to a set of unmasked patches, and a smaller decoder tries to reconstruct the masked pixel values.
|
||||
|
||||
You can use it with the following code
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import ViT, MAE
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
mae = MAE(
|
||||
encoder = v,
|
||||
masking_ratio = 0.75,
|
||||
decoder_dim = 1024,
|
||||
decoder_depth = 6,
|
||||
decoder_heads = 8
|
||||
)
|
||||
|
||||
images = torch.randn(8, 3, 256, 256)
|
||||
|
||||
loss = mae(images)
|
||||
loss.backward()
|
||||
|
||||
# that's all!
|
||||
# do the above in a for loop many times with a lot of images and your vision transformer will learn
|
||||
```
|
||||
|
||||
## Masked Patch Prediction
|
||||
|
||||
Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
|
||||
@@ -411,7 +562,7 @@ mpp_trainer = MPP(
|
||||
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
|
||||
|
||||
def sample_unlabelled_images():
|
||||
return torch.randn(20, 3, 256, 256)
|
||||
return torch.FloatTensor(20, 3, 256, 256).uniform_(0., 1.)
|
||||
|
||||
for _ in range(100):
|
||||
images = sample_unlabelled_images()
|
||||
@@ -654,6 +805,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
|
||||
## Citations
|
||||
```bibtex
|
||||
@article{hassani2021escaping,
|
||||
title = {Escaping the Big Data Paradigm with Compact Transformers},
|
||||
author = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
|
||||
year = 2021,
|
||||
url = {https://arxiv.org/abs/2104.05704},
|
||||
eprint = {2104.05704},
|
||||
archiveprefix = {arXiv},
|
||||
primaryclass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{dosovitskiy2020image,
|
||||
@@ -787,6 +949,28 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{zhang2021aggregating,
|
||||
title = {Aggregating Nested Transformers},
|
||||
author = {Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
|
||||
year = {2021},
|
||||
eprint = {2105.12723},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{chen2021regionvit,
|
||||
title = {RegionViT: Regional-to-Local Attention for Vision Transformers},
|
||||
author = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan},
|
||||
year = {2021},
|
||||
eprint = {2106.02689},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{caron2021emerging,
|
||||
title = {Emerging Properties in Self-Supervised Vision Transformers},
|
||||
@@ -798,6 +982,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{he2021masked,
|
||||
title = {Masked Autoencoders Are Scalable Vision Learners},
|
||||
author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
|
||||
year = {2021},
|
||||
eprint = {2111.06377},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
|
||||
@@ -364,9 +364,8 @@
|
||||
"\n",
|
||||
"val_transforms = transforms.Compose(\n",
|
||||
" [\n",
|
||||
" transforms.Resize((224, 224)),\n",
|
||||
" transforms.RandomResizedCrop(224),\n",
|
||||
" transforms.RandomHorizontalFlip(),\n",
|
||||
" transforms.Resize(256),\n",
|
||||
" transforms.CenterCrop(224),\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
@@ -374,9 +373,8 @@
|
||||
"\n",
|
||||
"test_transforms = transforms.Compose(\n",
|
||||
" [\n",
|
||||
" transforms.Resize((224, 224)),\n",
|
||||
" transforms.RandomResizedCrop(224),\n",
|
||||
" transforms.RandomHorizontalFlip(),\n",
|
||||
" transforms.Resize(256),\n",
|
||||
" transforms.CenterCrop(224),\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" ]\n",
|
||||
")\n"
|
||||
@@ -6250,4 +6248,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
||||
}
|
||||
BIN
images/mae.png
Normal file
BIN
images/mae.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 198 KiB |
BIN
images/nest.png
Normal file
BIN
images/nest.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 75 KiB |
BIN
images/regionvit.png
Normal file
BIN
images/regionvit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 94 KiB |
BIN
images/regionvit2.png
Normal file
BIN
images/regionvit2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 55 KiB |
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.18.4',
|
||||
version = '0.22.0',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.mae import MAE
|
||||
from vit_pytorch.dino import Dino
|
||||
|
||||
339
vit_pytorch/cct.py
Normal file
339
vit_pytorch/cct.py
Normal file
@@ -0,0 +1,339 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Pre-defined CCT Models
|
||||
__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']
|
||||
|
||||
|
||||
def cct_2(*args, **kwargs):
|
||||
return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_4(*args, **kwargs):
|
||||
return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_6(*args, **kwargs):
|
||||
return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_7(*args, **kwargs):
|
||||
return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_8(*args, **kwargs):
|
||||
return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_14(*args, **kwargs):
|
||||
return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def cct_16(*args, **kwargs):
|
||||
return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
|
||||
kernel_size=3, stride=None, padding=None,
|
||||
*args, **kwargs):
|
||||
stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
|
||||
padding = padding if padding is not None else max(1, (kernel_size // 2))
|
||||
return CCT(num_layers=num_layers,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
embedding_dim=embedding_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
# Modules
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // self.num_heads
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
self.attn_drop = nn.Dropout(attention_dropout)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(projection_dropout)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
Inspired by torch.nn.TransformerEncoderLayer and
|
||||
rwightman's timm package.
|
||||
"""
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
||||
attention_dropout=0.1, drop_path_rate=0.1):
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
self.pre_norm = nn.LayerNorm(d_model)
|
||||
self.self_attn = Attention(dim=d_model, num_heads=nhead,
|
||||
attention_dropout=attention_dropout, projection_dropout=dropout)
|
||||
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
||||
|
||||
self.activation = F.gelu
|
||||
|
||||
def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
|
||||
src = src + self.drop_path(self.dropout2(src2))
|
||||
return src
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||
"""
|
||||
Obtained from: github.com:rwightman/pytorch-image-models
|
||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
"""
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = x.div(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""
|
||||
Obtained from: github.com:rwightman/pytorch-image-models
|
||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
class Tokenizer(nn.Module):
|
||||
def __init__(self,
|
||||
kernel_size, stride, padding,
|
||||
pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
|
||||
n_conv_layers=1,
|
||||
n_input_channels=3,
|
||||
n_output_channels=64,
|
||||
in_planes=64,
|
||||
activation=None,
|
||||
max_pool=True,
|
||||
conv_bias=False):
|
||||
super(Tokenizer, self).__init__()
|
||||
|
||||
n_filter_list = [n_input_channels] + \
|
||||
[in_planes for _ in range(n_conv_layers - 1)] + \
|
||||
[n_output_channels]
|
||||
|
||||
self.conv_layers = nn.Sequential(
|
||||
*[nn.Sequential(
|
||||
nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
|
||||
kernel_size=(kernel_size, kernel_size),
|
||||
stride=(stride, stride),
|
||||
padding=(padding, padding), bias=conv_bias),
|
||||
nn.Identity() if activation is None else activation(),
|
||||
nn.MaxPool2d(kernel_size=pooling_kernel_size,
|
||||
stride=pooling_stride,
|
||||
padding=pooling_padding) if max_pool else nn.Identity()
|
||||
)
|
||||
for i in range(n_conv_layers)
|
||||
])
|
||||
|
||||
self.flattener = nn.Flatten(2, 3)
|
||||
self.apply(self.init_weight)
|
||||
|
||||
def sequence_length(self, n_channels=3, height=224, width=224):
|
||||
return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
|
||||
|
||||
def forward(self, x):
|
||||
return self.flattener(self.conv_layers(x)).transpose(-2, -1)
|
||||
|
||||
@staticmethod
|
||||
def init_weight(m):
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight)
|
||||
|
||||
|
||||
class TransformerClassifier(nn.Module):
|
||||
def __init__(self,
|
||||
seq_pool=True,
|
||||
embedding_dim=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
num_classes=1000,
|
||||
dropout_rate=0.1,
|
||||
attention_dropout=0.1,
|
||||
stochastic_depth_rate=0.1,
|
||||
positional_embedding='sine',
|
||||
sequence_length=None,
|
||||
*args, **kwargs):
|
||||
super().__init__()
|
||||
positional_embedding = positional_embedding if \
|
||||
positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
|
||||
dim_feedforward = int(embedding_dim * mlp_ratio)
|
||||
self.embedding_dim = embedding_dim
|
||||
self.sequence_length = sequence_length
|
||||
self.seq_pool = seq_pool
|
||||
|
||||
assert sequence_length is not None or positional_embedding == 'none', \
|
||||
f"Positional embedding is set to {positional_embedding} and" \
|
||||
f" the sequence length was not specified."
|
||||
|
||||
if not seq_pool:
|
||||
sequence_length += 1
|
||||
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim),
|
||||
requires_grad=True)
|
||||
else:
|
||||
self.attention_pool = nn.Linear(self.embedding_dim, 1)
|
||||
|
||||
if positional_embedding != 'none':
|
||||
if positional_embedding == 'learnable':
|
||||
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
|
||||
requires_grad=True)
|
||||
nn.init.trunc_normal_(self.positional_emb, std=0.2)
|
||||
else:
|
||||
self.positional_emb = nn.Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
|
||||
requires_grad=False)
|
||||
else:
|
||||
self.positional_emb = None
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
|
||||
dim_feedforward=dim_feedforward, dropout=dropout_rate,
|
||||
attention_dropout=attention_dropout, drop_path_rate=dpr[i])
|
||||
for i in range(num_layers)])
|
||||
self.norm = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.fc = nn.Linear(embedding_dim, num_classes)
|
||||
self.apply(self.init_weight)
|
||||
|
||||
def forward(self, x):
|
||||
if self.positional_emb is None and x.size(1) < self.sequence_length:
|
||||
x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)
|
||||
|
||||
if not self.seq_pool:
|
||||
cls_token = self.class_emb.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
|
||||
if self.positional_emb is not None:
|
||||
x += self.positional_emb
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
x = self.norm(x)
|
||||
|
||||
if self.seq_pool:
|
||||
x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
|
||||
else:
|
||||
x = x[:, 0]
|
||||
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def init_weight(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
@staticmethod
|
||||
def sinusoidal_embedding(n_channels, dim):
|
||||
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
|
||||
for p in range(n_channels)])
|
||||
pe[:, 0::2] = torch.sin(pe[:, 0::2])
|
||||
pe[:, 1::2] = torch.cos(pe[:, 1::2])
|
||||
return pe.unsqueeze(0)
|
||||
|
||||
|
||||
# CCT Main model
|
||||
class CCT(nn.Module):
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
embedding_dim=768,
|
||||
n_input_channels=3,
|
||||
n_conv_layers=1,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
*args, **kwargs):
|
||||
super(CCT, self).__init__()
|
||||
|
||||
self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
|
||||
n_output_channels=embedding_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
pooling_stride=pooling_stride,
|
||||
pooling_padding=pooling_padding,
|
||||
max_pool=True,
|
||||
activation=nn.ReLU,
|
||||
n_conv_layers=n_conv_layers,
|
||||
conv_bias=False)
|
||||
|
||||
self.classifier = TransformerClassifier(
|
||||
sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
|
||||
height=img_size,
|
||||
width=img_size),
|
||||
embedding_dim=embedding_dim,
|
||||
seq_pool=True,
|
||||
dropout_rate=0.,
|
||||
attention_dropout=0.1,
|
||||
stochastic_depth=0.1,
|
||||
*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.tokenizer(x)
|
||||
return self.classifier(x)
|
||||
|
||||
@@ -148,6 +148,6 @@ class DistillWrapper(nn.Module):
|
||||
|
||||
else:
|
||||
teacher_labels = teacher_logits.argmax(dim = -1)
|
||||
distill_loss = F.cross_entropy(student_logits, teacher_labels)
|
||||
distill_loss = F.cross_entropy(distill_logits, teacher_labels)
|
||||
|
||||
return loss * (1 - alpha) + distill_loss * alpha
|
||||
|
||||
@@ -29,7 +29,7 @@ class FeedForward(nn.Module):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim * mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Hardswish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(dim * mult, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
|
||||
93
vit_pytorch/mae.py
Normal file
93
vit_pytorch/mae.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import torch
|
||||
from math import ceil
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from vit_pytorch.vit import Transformer
|
||||
|
||||
class MAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
encoder,
|
||||
decoder_dim,
|
||||
masking_ratio = 0.75,
|
||||
decoder_depth = 1,
|
||||
decoder_heads = 8,
|
||||
decoder_dim_head = 64
|
||||
):
|
||||
super().__init__()
|
||||
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
|
||||
self.masking_ratio = masking_ratio
|
||||
|
||||
# extract some hyperparameters and functions from encoder (vision transformer to be trained)
|
||||
|
||||
self.encoder = encoder
|
||||
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
|
||||
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
|
||||
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
|
||||
|
||||
# decoder parameters
|
||||
|
||||
self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
|
||||
self.mask_token = nn.Parameter(torch.randn(decoder_dim))
|
||||
self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
|
||||
self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
|
||||
self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)
|
||||
|
||||
def forward(self, img):
|
||||
device = img.device
|
||||
|
||||
# get patches
|
||||
|
||||
patches = self.to_patch(img)
|
||||
batch, num_patches, *_ = patches.shape
|
||||
|
||||
# patch to encoder tokens and add positions
|
||||
|
||||
tokens = self.patch_to_emb(patches)
|
||||
tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]
|
||||
|
||||
# calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked
|
||||
|
||||
num_masked = int(self.masking_ratio * num_patches)
|
||||
rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
|
||||
masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]
|
||||
|
||||
# get the unmasked tokens to be encoded
|
||||
|
||||
batch_range = torch.arange(batch, device = device)[:, None]
|
||||
tokens = tokens[batch_range, unmasked_indices]
|
||||
|
||||
# get the patches to be masked for the final reconstruction loss
|
||||
|
||||
masked_patches = patches[batch_range, masked_indices]
|
||||
|
||||
# attend with vision transformer
|
||||
|
||||
encoded_tokens = self.encoder.transformer(tokens)
|
||||
|
||||
# project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder
|
||||
|
||||
decoder_tokens = self.enc_to_dec(encoded_tokens)
|
||||
|
||||
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
|
||||
|
||||
mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
|
||||
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
|
||||
|
||||
# concat the masked tokens to the decoder tokens and attend with decoder
|
||||
|
||||
decoder_tokens = torch.cat((decoder_tokens, mask_tokens), dim = 1)
|
||||
decoded_tokens = self.decoder(decoder_tokens)
|
||||
|
||||
# splice out the mask tokens and project to pixel values
|
||||
|
||||
mask_tokens = decoded_tokens[:, -num_masked:]
|
||||
pred_pixel_values = self.to_pixels(mask_tokens)
|
||||
|
||||
# calculate reconstruction loss
|
||||
|
||||
recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
|
||||
return recon_loss
|
||||
@@ -1,20 +1,20 @@
|
||||
import math
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops import rearrange, repeat, reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def prob_mask_like(t, prob):
|
||||
batch, seq_length, _ = t.shape
|
||||
return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob
|
||||
|
||||
|
||||
def get_mask_subset_with_prob(patched_input, prob):
|
||||
batch, seq_len, _, device = *patched_input.shape, patched_input.device
|
||||
max_masked = math.ceil(prob * seq_len)
|
||||
@@ -31,43 +31,45 @@ def get_mask_subset_with_prob(patched_input, prob):
|
||||
|
||||
|
||||
class MPPLoss(nn.Module):
|
||||
def __init__(self, patch_size, channels, output_channel_bits,
|
||||
max_pixel_val):
|
||||
super(MPPLoss, self).__init__()
|
||||
def __init__(
|
||||
self,
|
||||
patch_size,
|
||||
channels,
|
||||
output_channel_bits,
|
||||
max_pixel_val,
|
||||
mean,
|
||||
std
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.channels = channels
|
||||
self.output_channel_bits = output_channel_bits
|
||||
self.max_pixel_val = max_pixel_val
|
||||
|
||||
self.mean = torch.tensor(mean).view(-1, 1, 1) if mean else None
|
||||
self.std = torch.tensor(std).view(-1, 1, 1) if std else None
|
||||
|
||||
def forward(self, predicted_patches, target, mask):
|
||||
p, c, mpv, bits, device = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits, target.device
|
||||
bin_size = mpv / (2 ** bits)
|
||||
|
||||
# un-normalize input
|
||||
if exists(self.mean) and exists(self.std):
|
||||
target = target * self.std + self.mean
|
||||
|
||||
# reshape target to patches
|
||||
p = self.patch_size
|
||||
target = rearrange(target,
|
||||
"b c (h p1) (w p2) -> b (h w) c (p1 p2) ",
|
||||
p1=p,
|
||||
p2=p)
|
||||
target = target.clamp(max = mpv) # clamp just in case
|
||||
avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1 = p, p2 = p).contiguous()
|
||||
|
||||
avg_target = target.mean(dim=3)
|
||||
|
||||
bin_size = self.max_pixel_val / self.output_channel_bits
|
||||
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size).to(avg_target.device)
|
||||
channel_bins = torch.arange(bin_size, mpv, bin_size, device = device)
|
||||
discretized_target = torch.bucketize(avg_target, channel_bins)
|
||||
discretized_target = F.one_hot(discretized_target,
|
||||
self.output_channel_bits)
|
||||
c, bi = self.channels, self.output_channel_bits
|
||||
discretized_target = rearrange(discretized_target,
|
||||
"b n c bi -> b n (c bi)",
|
||||
c=c,
|
||||
bi=bi)
|
||||
|
||||
bin_mask = 2**torch.arange(c * bi - 1, -1,
|
||||
-1).to(discretized_target.device,
|
||||
discretized_target.dtype)
|
||||
target_label = torch.sum(bin_mask * discretized_target, -1)
|
||||
bin_mask = (2 ** bits) ** torch.arange(0, c, device = device).long()
|
||||
bin_mask = rearrange(bin_mask, 'c -> () () c')
|
||||
|
||||
predicted_patches = predicted_patches[mask]
|
||||
target_label = target_label[mask]
|
||||
loss = F.cross_entropy(predicted_patches, target_label)
|
||||
target_label = torch.sum(bin_mask * discretized_target, dim = -1)
|
||||
|
||||
loss = F.cross_entropy(predicted_patches[mask], target_label[mask])
|
||||
return loss
|
||||
|
||||
|
||||
@@ -75,20 +77,24 @@ class MPPLoss(nn.Module):
|
||||
|
||||
|
||||
class MPP(nn.Module):
|
||||
def __init__(self,
|
||||
transformer,
|
||||
patch_size,
|
||||
dim,
|
||||
output_channel_bits=3,
|
||||
channels=3,
|
||||
max_pixel_val=1.0,
|
||||
mask_prob=0.15,
|
||||
replace_prob=0.5,
|
||||
random_patch_prob=0.5):
|
||||
def __init__(
|
||||
self,
|
||||
transformer,
|
||||
patch_size,
|
||||
dim,
|
||||
output_channel_bits=3,
|
||||
channels=3,
|
||||
max_pixel_val=1.0,
|
||||
mask_prob=0.15,
|
||||
replace_prob=0.5,
|
||||
random_patch_prob=0.5,
|
||||
mean=None,
|
||||
std=None
|
||||
):
|
||||
super().__init__()
|
||||
self.transformer = transformer
|
||||
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
|
||||
max_pixel_val)
|
||||
max_pixel_val, mean, std)
|
||||
|
||||
# output transformation
|
||||
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
|
||||
@@ -102,7 +108,7 @@ class MPP(nn.Module):
|
||||
self.random_patch_prob = random_patch_prob
|
||||
|
||||
# token ids
|
||||
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
|
||||
self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))
|
||||
|
||||
def forward(self, input, **kwargs):
|
||||
transformer = self.transformer
|
||||
|
||||
179
vit_pytorch/nest.py
Normal file
179
vit_pytorch/nest.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from functools import partial
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def cast_tuple(val, depth):
|
||||
return val if isinstance(val, tuple) else ((val,) * depth)
|
||||
|
||||
# classes
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (std + self.eps) * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mlp_mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim * mlp_mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(dim * mlp_mult, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dropout = 0.):
|
||||
super().__init__()
|
||||
dim_head = dim // heads
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w, heads = *x.shape, self.heads
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = 1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
|
||||
return self.to_out(out)
|
||||
|
||||
def Aggregate(dim, dim_out):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(dim, dim_out, 3, padding = 1),
|
||||
LayerNorm(dim_out),
|
||||
nn.MaxPool2d(3, stride = 2, padding = 1)
|
||||
)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
self.pos_emb = nn.Parameter(torch.randn(seq_len))
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
*_, h, w = x.shape
|
||||
|
||||
pos_emb = self.pos_emb[:(h * w)]
|
||||
pos_emb = rearrange(pos_emb, '(h w) -> () () h w', h = h, w = w)
|
||||
x = x + pos_emb
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class NesT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
heads,
|
||||
num_hierarchies,
|
||||
block_repeats,
|
||||
mlp_mult = 4,
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
fmap_size = image_size // patch_size
|
||||
blocks = 2 ** (num_hierarchies - 1)
|
||||
|
||||
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy
|
||||
hierarchies = list(reversed(range(num_hierarchies)))
|
||||
mults = [2 ** i for i in hierarchies]
|
||||
|
||||
layer_heads = list(map(lambda t: t * heads, mults))
|
||||
layer_dims = list(map(lambda t: t * dim, mults))
|
||||
|
||||
layer_dims = [*layer_dims, layer_dims[-1]]
|
||||
dim_pairs = zip(layer_dims[:-1], layer_dims[1:])
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
|
||||
nn.Conv2d(patch_dim, layer_dims[0], 1),
|
||||
)
|
||||
|
||||
block_repeats = cast_tuple(block_repeats, num_hierarchies)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats):
|
||||
is_last = level == 0
|
||||
depth = block_repeat
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout),
|
||||
Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
LayerNorm(dim),
|
||||
Reduce('b c h w -> b c', 'mean'),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, c, h, w = x.shape
|
||||
|
||||
num_hierarchies = len(self.layers)
|
||||
|
||||
for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
|
||||
block_size = 2 ** level
|
||||
x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
|
||||
x = transformer(x)
|
||||
x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
|
||||
x = aggregate(x)
|
||||
|
||||
return self.mlp_head(x)
|
||||
@@ -175,7 +175,7 @@ class PiT(nn.Module):
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding
|
||||
x += self.pos_embedding[:, :n+1]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.layers(x)
|
||||
|
||||
@@ -8,7 +8,7 @@ def find_modules(nn_module, type):
|
||||
return [module for module in nn_module.modules() if isinstance(module, type)]
|
||||
|
||||
class Recorder(nn.Module):
|
||||
def __init__(self, vit):
|
||||
def __init__(self, vit, device = None):
|
||||
super().__init__()
|
||||
self.vit = vit
|
||||
|
||||
@@ -17,6 +17,7 @@ class Recorder(nn.Module):
|
||||
self.hooks = []
|
||||
self.hook_registered = False
|
||||
self.ejected = False
|
||||
self.device = device
|
||||
|
||||
def _hook(self, _, input, output):
|
||||
self.recordings.append(output.clone().detach())
|
||||
@@ -45,10 +46,14 @@ class Recorder(nn.Module):
|
||||
def forward(self, img):
|
||||
assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
|
||||
self.clear()
|
||||
|
||||
if not self.hook_registered:
|
||||
self._register_hook()
|
||||
|
||||
pred = self.vit(img)
|
||||
attns = torch.stack(self.recordings, dim = 1)
|
||||
|
||||
# move all recordings to one device before stacking
|
||||
target_device = self.device if self.device is not None else img.device
|
||||
recordings = tuple(map(lambda t: t.to(target_device), self.recordings))
|
||||
|
||||
attns = torch.stack(recordings, dim = 1)
|
||||
return pred, attns
|
||||
|
||||
267
vit_pytorch/regionvit.py
Normal file
267
vit_pytorch/regionvit.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
import torch.nn.functional as F
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
def divisible_by(val, d):
|
||||
return (val % d) == 0
|
||||
|
||||
# helper classes
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class PEG(nn.Module):
|
||||
def __init__(self, dim, kernel_size = 3):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(x) + x
|
||||
|
||||
# transformer classes
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0.):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, dim * mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim * mult, dim, 1)
|
||||
)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
heads = 4,
|
||||
dim_head = 32,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim)
|
||||
|
||||
def forward(self, x, rel_pos_bias = None):
|
||||
h = self.heads
|
||||
|
||||
# prenorm
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
|
||||
|
||||
# split heads
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
|
||||
q = q * self.scale
|
||||
|
||||
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
|
||||
# add relative positional bias for local tokens
|
||||
|
||||
if exists(rel_pos_bias):
|
||||
sim = sim + rel_pos_bias
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class R2LTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
window_size,
|
||||
depth = 4,
|
||||
heads = 4,
|
||||
dim_head = 32,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
self.window_size = window_size
|
||||
rel_positions = 2 * window_size - 1
|
||||
self.local_rel_pos_bias = nn.Embedding(rel_positions ** 2, heads)
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
|
||||
FeedForward(dim, dropout = ff_dropout)
|
||||
]))
|
||||
|
||||
def forward(self, local_tokens, region_tokens):
|
||||
device = local_tokens.device
|
||||
lh, lw = local_tokens.shape[-2:]
|
||||
rh, rw = region_tokens.shape[-2:]
|
||||
window_size_h, window_size_w = lh // rh, lw // rw
|
||||
|
||||
local_tokens = rearrange(local_tokens, 'b c h w -> b (h w) c')
|
||||
region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')
|
||||
|
||||
# calculate local relative positional bias
|
||||
|
||||
h_range = torch.arange(window_size_h, device = device)
|
||||
w_range = torch.arange(window_size_w, device = device)
|
||||
|
||||
grid_x, grid_y = torch.meshgrid(h_range, w_range)
|
||||
grid = torch.stack((grid_x, grid_y))
|
||||
grid = rearrange(grid, 'c h w -> c (h w)')
|
||||
grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)
|
||||
bias_indices = (grid * torch.tensor([1, self.window_size * 2 - 1], device = device)[:, None, None]).sum(dim = 0)
|
||||
rel_pos_bias = self.local_rel_pos_bias(bias_indices)
|
||||
rel_pos_bias = rearrange(rel_pos_bias, 'i j h -> () h i j')
|
||||
rel_pos_bias = F.pad(rel_pos_bias, (1, 0, 1, 0), value = 0)
|
||||
|
||||
# go through r2l transformer layers
|
||||
|
||||
for attn, ff in self.layers:
|
||||
region_tokens = attn(region_tokens) + region_tokens
|
||||
|
||||
# concat region tokens to local tokens
|
||||
|
||||
local_tokens = rearrange(local_tokens, 'b (h w) d -> b h w d', h = lh)
|
||||
local_tokens = rearrange(local_tokens, 'b (h p1) (w p2) d -> (b h w) (p1 p2) d', p1 = window_size_h, p2 = window_size_w)
|
||||
region_tokens = rearrange(region_tokens, 'b n d -> (b n) () d')
|
||||
|
||||
# do self attention on local tokens, along with its regional token
|
||||
|
||||
region_and_local_tokens = torch.cat((region_tokens, local_tokens), dim = 1)
|
||||
region_and_local_tokens = attn(region_and_local_tokens, rel_pos_bias = rel_pos_bias) + region_and_local_tokens
|
||||
|
||||
# feedforward
|
||||
|
||||
region_and_local_tokens = ff(region_and_local_tokens) + region_and_local_tokens
|
||||
|
||||
# split back local and regional tokens
|
||||
|
||||
region_tokens, local_tokens = region_and_local_tokens[:, :1], region_and_local_tokens[:, 1:]
|
||||
local_tokens = rearrange(local_tokens, '(b h w) (p1 p2) d -> b (h p1 w p2) d', h = lh // window_size_h, w = lw // window_size_w, p1 = window_size_h)
|
||||
region_tokens = rearrange(region_tokens, '(b n) () d -> b n d', n = rh * rw)
|
||||
|
||||
local_tokens = rearrange(local_tokens, 'b (h w) c -> b c h w', h = lh, w = lw)
|
||||
region_tokens = rearrange(region_tokens, 'b (h w) c -> b c h w', h = rh, w = rw)
|
||||
return local_tokens, region_tokens
|
||||
|
||||
# classes
|
||||
|
||||
class RegionViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim = (64, 128, 256, 512),
|
||||
depth = (2, 2, 8, 2),
|
||||
window_size = 7,
|
||||
num_classes = 1000,
|
||||
tokenize_local_3_conv = False,
|
||||
local_patch_size = 4,
|
||||
use_peg = False,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.,
|
||||
channels = 3,
|
||||
):
|
||||
super().__init__()
|
||||
dim = cast_tuple(dim, 4)
|
||||
depth = cast_tuple(depth, 4)
|
||||
assert len(dim) == 4, 'dim needs to be a single value or a tuple of length 4'
|
||||
assert len(depth) == 4, 'depth needs to be a single value or a tuple of length 4'
|
||||
|
||||
self.local_patch_size = local_patch_size
|
||||
|
||||
region_patch_size = local_patch_size * window_size
|
||||
self.region_patch_size = local_patch_size * window_size
|
||||
|
||||
init_dim, *_, last_dim = dim
|
||||
|
||||
# local and region encoders
|
||||
|
||||
if tokenize_local_3_conv:
|
||||
self.local_encoder = nn.Sequential(
|
||||
nn.Conv2d(3, init_dim, 3, 2, 1),
|
||||
nn.LayerNorm(init_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(init_dim, init_dim, 3, 2, 1),
|
||||
nn.LayerNorm(init_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(init_dim, init_dim, 3, 1, 1)
|
||||
)
|
||||
else:
|
||||
self.local_encoder = nn.Conv2d(3, init_dim, 8, 4, 3)
|
||||
|
||||
self.region_encoder = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = region_patch_size, p2 = region_patch_size),
|
||||
nn.Conv2d((region_patch_size ** 2) * channels, init_dim, 1)
|
||||
)
|
||||
|
||||
# layers
|
||||
|
||||
current_dim = init_dim
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
for ind, dim, num_layers in zip(range(4), dim, depth):
|
||||
not_first = ind != 0
|
||||
need_downsample = not_first
|
||||
need_peg = not_first and use_peg
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
Downsample(current_dim, dim) if need_downsample else nn.Identity(),
|
||||
PEG(dim) if need_peg else nn.Identity(),
|
||||
R2LTransformer(dim, depth = num_layers, window_size = window_size, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
|
||||
]))
|
||||
|
||||
current_dim = dim
|
||||
|
||||
# final logits
|
||||
|
||||
self.to_logits = nn.Sequential(
|
||||
Reduce('b c h w -> b c', 'mean'),
|
||||
nn.LayerNorm(last_dim),
|
||||
nn.Linear(last_dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_local_tokens = False
|
||||
):
|
||||
*_, h, w = x.shape
|
||||
assert divisible_by(h, self.region_patch_size) and divisible_by(w, self.region_patch_size), 'height and width must be divisible by region patch size'
|
||||
assert divisible_by(h, self.local_patch_size) and divisible_by(w, self.local_patch_size), 'height and width must be divisible by local patch size'
|
||||
|
||||
local_tokens = self.local_encoder(x)
|
||||
region_tokens = self.region_encoder(x)
|
||||
|
||||
for down, peg, transformer in self.layers:
|
||||
local_tokens, region_tokens = down(local_tokens), down(region_tokens)
|
||||
local_tokens = peg(local_tokens)
|
||||
local_tokens, region_tokens = transformer(local_tokens, region_tokens)
|
||||
|
||||
return self.to_logits(region_tokens)
|
||||
@@ -19,7 +19,7 @@ class AxialRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_freq = 10):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
|
||||
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
|
||||
self.register_buffer('scales', scales)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -154,10 +154,10 @@ class Attention(nn.Module):
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, image_size, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
self.pos_emb = AxialRotaryEmbedding(dim_head)
|
||||
self.pos_emb = AxialRotaryEmbedding(dim_head, max_freq = image_size)
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
|
||||
@@ -187,7 +187,7 @@ class RvT(nn.Module):
|
||||
)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary, use_ds_conv, use_glu)
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, image_size, dropout, use_rotary, use_ds_conv, use_glu)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
@@ -35,13 +35,14 @@ class T2TViT(nn.Module):
|
||||
for i, (kernel_size, stride) in enumerate(t2t_layers):
|
||||
layer_dim *= kernel_size ** 2
|
||||
is_first = i == 0
|
||||
is_last = i == (len(t2t_layers) - 1)
|
||||
output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)
|
||||
|
||||
layers.extend([
|
||||
RearrangeImage() if not is_first else nn.Identity(),
|
||||
nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
|
||||
Rearrange('b c n -> b n c'),
|
||||
Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout),
|
||||
Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout) if not is_last else nn.Identity(),
|
||||
])
|
||||
|
||||
layers.append(nn.Linear(layer_dim, dim))
|
||||
@@ -71,7 +72,7 @@ class T2TViT(nn.Module):
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding
|
||||
x += self.pos_embedding[:, :n+1]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
@@ -51,15 +50,14 @@ class Attention(nn.Module):
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user