Compare commits

...

39 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Phil Wang | e8f6d72033 | release masked autoencoder | 2021-11-12 20:08:48 -08:00 |
| Phil Wang | cb1729af28 | more efficient feedforward for regionvit | 2021-11-07 17:18:59 -08:00 |
| Phil Wang | 9e50b2a41e | readme | 2021-11-07 09:59:49 -08:00 |
| Phil Wang | 06d375351e | add RegionViT paper | 2021-11-07 09:47:28 -08:00 |
| Phil Wang | f196d1ec5b | move freqs in RvT to linspace | 2021-10-05 09:23:44 -07:00 |
| Phil Wang | 529044c9b3 | Merge pull request #153 from developer0hye/fix-example (fix transforms for val an test process in example code) | 2021-09-02 06:57:16 -07:00 |
| yhkwon-DT01 | c30655f3bc | fix transforms for val an test process | 2021-09-02 17:30:18 +09:00 |
| Phil Wang | d2d6de01d3 | 0.20.7 | 2021-08-30 08:14:43 -07:00 |
| Phil Wang | b9eadaef60 | Merge pull request #151 from developer0hye/patch-1 (Cleanup Attention Class & matmul based implementation for TensorRT conversion) | 2021-08-30 08:14:11 -07:00 |
| Yonghye Kwon | 24ac8350bf | remove unused package | 2021-08-30 18:25:03 +09:00 |
| Yonghye Kwon | ca3cef9de0 | Cleanup Attention Class | 2021-08-30 18:05:16 +09:00 |
| Phil Wang | 6e1be11517 | 0.20.6 | 2021-08-21 09:03:54 -07:00 |
| Phil Wang | 73ed562ce4 | Merge pull request #147 from developer0hye/patch-4 (Make T2T process any scale image) | 2021-08-21 09:03:42 -07:00 |
| Phil Wang | ff863175a6 | Merge pull request #146 from developer0hye/patch-1 (Make Pit process image with width and height less than the image_size) | 2021-08-21 09:03:31 -07:00 |
| Yonghye Kwon | ca0bdca192 | Make model process any scale image (Related to #145) | 2021-08-21 22:35:26 +09:00 |
| Yonghye Kwon | 1c70271778 | Support image with width and height less than the image_size (Related to #145) | 2021-08-21 22:25:46 +09:00 |
| Phil Wang | d7d3febfe3 | Merge pull request #144 from developer0hye/patch-1 (Remove unused package) | 2021-08-20 10:14:02 -07:00 |
| Yonghye Kwon | 946815164a | Remove unused package | 2021-08-20 13:44:57 +09:00 |
| Phil Wang | aeed3381c1 | use hardswish for levit | 2021-08-19 08:22:55 -07:00 |
| Phil Wang | 3f754956fb | remove last transformer layer in t2t | 2021-08-14 08:06:23 -07:00 |
| Phil Wang | 918869571c | fix hard distillation, thanks to @CiaoHe | 2021-08-12 08:40:57 -07:00 |
| Phil Wang | e5324242be | fix wrong norm in nest | 2021-08-05 12:55:48 -07:00 |
| Phil Wang | 22da26fa4b | fix recorder in data parallel situation | 2021-07-08 10:15:07 -07:00 |
| Phil Wang | a6c085a2df | 0.20.0 for cct | 2021-07-02 15:48:48 -07:00 |
| Phil Wang | 121353c604 | Merge pull request #128 from stevenwalton/main (Adding Compact Convolutional Transformers (CCT)) | 2021-07-02 15:48:32 -07:00 |
| alih | 2ece3333da | Minor changes | 2021-07-01 17:51:35 -07:00 |
| Ali Hassani | a73030c9aa | Update README.md | 2021-07-01 16:41:27 -07:00 |
| Steven Walton | 780f91a220 | Tested and changed README format | 2021-07-01 16:26:41 -07:00 |
| Steven Walton | 88451068e8 | Adding CCT (Adding Compact Convolutional Transformers (CCT) from Escaping the Big Data Paradigm with Compact Transformers by Hassani et. al., https://arxiv.org/abs/2104.05704) | 2021-07-01 16:22:33 -07:00 |
| Phil Wang | 64a2ef6462 | fix mpp | 2021-06-16 16:46:32 -07:00 |
| Phil Wang | 53884f583f | 0.19.5 | 2021-06-16 14:24:46 -07:00 |
| Phil Wang | e616b5dcbc | Merge pull request #101 from zankner/mpp-fix (Mpp fix) | 2021-06-16 14:24:26 -07:00 |
| Phil Wang | 60ad4e266e | layernorm on channel dimension == instancenorm2d with affine set to true | 2021-06-03 16:41:45 -07:00 |
| Phil Wang | a254a0258a | fix typo | 2021-06-01 07:33:00 -07:00 |
| Phil Wang | 26df10c0b7 | fix max pool in nest | 2021-05-28 11:06:02 -07:00 |
| Phil Wang | 17cb8976df | make nest resilient to dimension that are not divisible by number of heads | 2021-05-27 22:41:07 -07:00 |
| Phil Wang | daf3abbeb5 | add NesT | 2021-05-27 22:02:17 -07:00 |
| Zack Ankner | a2df363224 | adding un-normalizing targets and fix for mask token dimension | 2021-04-29 15:43:22 -04:00 |
| Zack Ankner | 710b6d57d3 | Merge pull request #1 from lucidrains/main (catch up) | 2021-04-29 19:33:25 +00:00 |
20 changed files with 1150 additions and 68 deletions

README.md (197 changed lines)

@@ -62,6 +62,7 @@ Dropout rate.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling
## Distillation
<img src="./images/distill.png" width="300px"></img>
@@ -118,6 +119,7 @@ v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```
## Deep ViT
This <a href="https://arxiv.org/abs/2103.11886">paper</a> notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a> paper from NLP.
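For orientation (not part of this hunk), the README pairs that description with a usage example along these lines; a minimal sketch, assuming the `DeepViT` class in `vit_pytorch.deepvit` accepts the same constructor arguments as the base `ViT`:
```python
import torch
from vit_pytorch.deepvit import DeepViT

v = DeepViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,       # depths past 12 are where Re-attention is meant to help
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
```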
@@ -201,6 +203,61 @@ img = torch.randn(1, 3, 224, 224)
preds = v(img) # (1, 1000)
```
## CCT
<img src="https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png" width="400px"></img>
<a href="https://arxiv.org/abs/2104.05704">CCT</a> proposes compact transformers that use convolutions instead of patching, together with sequence pooling. This lets CCT reach high accuracy with a low parameter count.
You can use it in one of two ways. The first is to configure the `CCT` class directly:
```python
import torch
from vit_pytorch.cct import CCT
model = CCT(
img_size=224,
embedding_dim=384,
n_conv_layers=2,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
num_layers=14,
num_heads=6,
mlp_ratio=3.,
num_classes=1000,
positional_embedding='learnable', # ['sine', 'learnable', 'none']
)
```
Alternatively, you can use one of several predefined variants `[2,4,6,7,8,14,16]`, which fix the number of layers, number of attention heads, mlp ratio, and embedding dimension for you.
```python
import torch
from vit_pytorch.cct import cct_14
model = cct_14(
img_size=224,
n_conv_layers=1,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
num_classes=1000,
positional_embedding='learnable', # ['sine', 'learnable', 'none']
)
```
The <a href="https://github.com/SHI-Labs/Compact-Transformers">official repository</a> includes links to pretrained model checkpoints.
## Cross ViT
<img src="./images/cross_vit.png" width="400px"></img>
@@ -378,6 +435,100 @@ img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```
## RegionViT
<img src="./images/regionvit.png" width="400px"></img>
<img src="./images/regionvit2.png" width="400px"></img>
<a href="https://arxiv.org/abs/2106.02689">This paper</a> proposes to divide up the feature map into local regions, whereby the local tokens attend to each other. Each local region has its own regional token which then attends to all its local tokens, as well as other regional tokens.
You can use it as follows
```python
import torch
from vit_pytorch.regionvit import RegionViT
model = RegionViT(
dim = (64, 128, 256, 512), # tuple of size 4, indicating dimension at each stage
depth = (2, 2, 8, 2), # depth of the region to local transformer at each stage
window_size = 7, # window size, which should be either 7 or 14
num_classes = 1000, # number of output classes
tokenize_local_3_conv = False, # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
use_peg = False, # whether to use positional generating module. they used this for object detection for a boost in performance
)
img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```
## NesT
<img src="./images/nest.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2105.12723">paper</a> processes the image in hierarchical stages, with attention only among tokens within local blocks, which are aggregated as they move up the hierarchy. The aggregation is done in the image plane and consists of a convolution followed by a maxpool, allowing information to pass across block boundaries.
You can use it with the following code (e.g. NesT-T)
```python
import torch
from vit_pytorch.nest import NesT
nest = NesT(
image_size = 224,
patch_size = 4,
dim = 96,
heads = 3,
num_hierarchies = 3, # number of hierarchies
block_repeats = (8, 4, 1), # the number of transformer blocks at each hierarchy, starting from the bottom
num_classes = 1000
)
img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```
## Masked Autoencoder
<img src="./images/mae.png" width="400px"/>
A new <a href="https://arxiv.org/abs/2111.06377">Kaiming He paper</a> proposes a simple autoencoder scheme where the vision transformer attends to a set of unmasked patches, and a smaller decoder tries to reconstruct the masked pixel values.
You can use it with the following code
```python
import torch
from vit_pytorch import ViT, MAE
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
)
mae = MAE(
encoder = v,
masking_ratio = 0.75,
decoder_dim = 1024,
decoder_depth = 6,
decoder_heads = 8
)
images = torch.randn(8, 3, 256, 256)
loss = mae(images)
loss.backward()
# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn
```
## Masked Patch Prediction
Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
@@ -411,7 +562,7 @@ mpp_trainer = MPP(
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
def sample_unlabelled_images():
return torch.randn(20, 3, 256, 256)
return torch.FloatTensor(20, 3, 256, 256).uniform_(0., 1.)
for _ in range(100):
images = sample_unlabelled_images()
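For context, the full training loop around this hunk looks roughly as follows; a sketch, assuming the `ViT`/`MPP` constructor arguments used elsewhere in the README:
```python
import torch
from vit_pytorch import ViT
from vit_pytorch.mpp import MPP

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mpp_trainer = MPP(
    transformer = model,
    patch_size = 32,
    dim = 1024,
    mask_prob = 0.15,          # probability of using a token in the masked prediction task
    random_patch_prob = 0.30,  # probability of randomly replacing a token being used for mpp
    replace_prob = 0.50        # probability of replacing a token being used for mpp with the mask token
)

opt = torch.optim.Adam(mpp_trainer.parameters(), lr = 3e-4)

def sample_unlabelled_images():
    # uniform [0, 1) images, matching the un-normalized targets introduced in this change
    return torch.FloatTensor(20, 3, 256, 256).uniform_(0., 1.)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = mpp_trainer(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
```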
@@ -654,6 +805,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
## Citations
```bibtex
@article{hassani2021escaping,
title = {Escaping the Big Data Paradigm with Compact Transformers},
author = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
year = {2021},
url = {https://arxiv.org/abs/2104.05704},
eprint = {2104.05704},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{dosovitskiy2020image,
@@ -787,6 +949,28 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{zhang2021aggregating,
title = {Aggregating Nested Transformers},
author = {Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
year = {2021},
eprint = {2105.12723},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{chen2021regionvit,
title = {RegionViT: Regional-to-Local Attention for Vision Transformers},
author = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan},
year = {2021},
eprint = {2106.02689},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{caron2021emerging,
title = {Emerging Properties in Self-Supervised Vision Transformers},
@@ -798,6 +982,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{he2021masked,
title = {Masked Autoencoders Are Scalable Vision Learners},
author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
year = {2021},
eprint = {2111.06377},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},

examples notebook (.ipynb)

@@ -364,9 +364,8 @@
"\n",
"val_transforms = transforms.Compose(\n",
" [\n",
" transforms.Resize((224, 224)),\n",
" transforms.RandomResizedCrop(224),\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.Resize(256),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" ]\n",
")\n",
@@ -374,9 +373,8 @@
"\n",
"test_transforms = transforms.Compose(\n",
" [\n",
" transforms.Resize((224, 224)),\n",
" transforms.RandomResizedCrop(224),\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.Resize(256),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" ]\n",
")\n"
@@ -6250,4 +6248,4 @@
},
"nbformat": 4,
"nbformat_minor": 1
}
}

images/mae.png (new binary file, 198 KiB)

images/nest.png (new binary file, 75 KiB)

images/regionvit.png (new binary file, 94 KiB)

images/regionvit2.png (new binary file, 55 KiB)

setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.18.4',
version = '0.22.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

vit_pytorch/__init__.py

@@ -1,2 +1,3 @@
from vit_pytorch.vit import ViT
from vit_pytorch.mae import MAE
from vit_pytorch.dino import Dino

vit_pytorch/cct.py (new file, 339 lines)

@@ -0,0 +1,339 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# Pre-defined CCT Models
__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']
def cct_2(*args, **kwargs):
return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
*args, **kwargs)
def cct_4(*args, **kwargs):
return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
*args, **kwargs)
def cct_6(*args, **kwargs):
return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs)
def cct_7(*args, **kwargs):
return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs)
def cct_8(*args, **kwargs):
return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs)
def cct_14(*args, **kwargs):
return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
*args, **kwargs)
def cct_16(*args, **kwargs):
return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384,
*args, **kwargs)
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
kernel_size=3, stride=None, padding=None,
*args, **kwargs):
stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
padding = padding if padding is not None else max(1, (kernel_size // 2))
return CCT(num_layers=num_layers,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
embedding_dim=embedding_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
*args, **kwargs)
# Modules
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
super().__init__()
self.num_heads = num_heads
head_dim = dim // self.num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.attn_drop = nn.Dropout(attention_dropout)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(projection_dropout)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class TransformerEncoderLayer(nn.Module):
"""
Inspired by torch.nn.TransformerEncoderLayer and
rwightman's timm package.
"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
attention_dropout=0.1, drop_path_rate=0.1):
super(TransformerEncoderLayer, self).__init__()
self.pre_norm = nn.LayerNorm(d_model)
self.self_attn = Attention(dim=d_model, num_heads=nhead,
attention_dropout=attention_dropout, projection_dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout2 = nn.Dropout(dropout)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.activation = F.gelu
def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
src = self.norm1(src)
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
src = src + self.drop_path(self.dropout2(src2))
return src
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""
Obtained from: github.com:rwightman/pytorch-image-models
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""
Obtained from: github.com:rwightman/pytorch-image-models
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class Tokenizer(nn.Module):
def __init__(self,
kernel_size, stride, padding,
pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
n_conv_layers=1,
n_input_channels=3,
n_output_channels=64,
in_planes=64,
activation=None,
max_pool=True,
conv_bias=False):
super(Tokenizer, self).__init__()
n_filter_list = [n_input_channels] + \
[in_planes for _ in range(n_conv_layers - 1)] + \
[n_output_channels]
self.conv_layers = nn.Sequential(
*[nn.Sequential(
nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
kernel_size=(kernel_size, kernel_size),
stride=(stride, stride),
padding=(padding, padding), bias=conv_bias),
nn.Identity() if activation is None else activation(),
nn.MaxPool2d(kernel_size=pooling_kernel_size,
stride=pooling_stride,
padding=pooling_padding) if max_pool else nn.Identity()
)
for i in range(n_conv_layers)
])
self.flattener = nn.Flatten(2, 3)
self.apply(self.init_weight)
def sequence_length(self, n_channels=3, height=224, width=224):
return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
def forward(self, x):
return self.flattener(self.conv_layers(x)).transpose(-2, -1)
@staticmethod
def init_weight(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
class TransformerClassifier(nn.Module):
def __init__(self,
seq_pool=True,
embedding_dim=768,
num_layers=12,
num_heads=12,
mlp_ratio=4.0,
num_classes=1000,
dropout_rate=0.1,
attention_dropout=0.1,
stochastic_depth_rate=0.1,
positional_embedding='sine',
sequence_length=None,
*args, **kwargs):
super().__init__()
positional_embedding = positional_embedding if \
positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
dim_feedforward = int(embedding_dim * mlp_ratio)
self.embedding_dim = embedding_dim
self.sequence_length = sequence_length
self.seq_pool = seq_pool
assert sequence_length is not None or positional_embedding == 'none', \
f"Positional embedding is set to {positional_embedding} and" \
f" the sequence length was not specified."
if not seq_pool:
sequence_length += 1
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim),
requires_grad=True)
else:
self.attention_pool = nn.Linear(self.embedding_dim, 1)
if positional_embedding != 'none':
if positional_embedding == 'learnable':
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
requires_grad=True)
nn.init.trunc_normal_(self.positional_emb, std=0.2)
else:
self.positional_emb = nn.Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
requires_grad=False)
else:
self.positional_emb = None
self.dropout = nn.Dropout(p=dropout_rate)
dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
self.blocks = nn.ModuleList([
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
dim_feedforward=dim_feedforward, dropout=dropout_rate,
attention_dropout=attention_dropout, drop_path_rate=dpr[i])
for i in range(num_layers)])
self.norm = nn.LayerNorm(embedding_dim)
self.fc = nn.Linear(embedding_dim, num_classes)
self.apply(self.init_weight)
def forward(self, x):
if self.positional_emb is None and x.size(1) < self.sequence_length:
x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)
if not self.seq_pool:
cls_token = self.class_emb.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
if self.positional_emb is not None:
x += self.positional_emb
x = self.dropout(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
if self.seq_pool:
x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
else:
x = x[:, 0]
x = self.fc(x)
return x
@staticmethod
def init_weight(m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@staticmethod
def sinusoidal_embedding(n_channels, dim):
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
for p in range(n_channels)])
pe[:, 0::2] = torch.sin(pe[:, 0::2])
pe[:, 1::2] = torch.cos(pe[:, 1::2])
return pe.unsqueeze(0)
# CCT Main model
class CCT(nn.Module):
def __init__(self,
img_size=224,
embedding_dim=768,
n_input_channels=3,
n_conv_layers=1,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
*args, **kwargs):
super(CCT, self).__init__()
self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
n_output_channels=embedding_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
pooling_kernel_size=pooling_kernel_size,
pooling_stride=pooling_stride,
pooling_padding=pooling_padding,
max_pool=True,
activation=nn.ReLU,
n_conv_layers=n_conv_layers,
conv_bias=False)
self.classifier = TransformerClassifier(
sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
height=img_size,
width=img_size),
embedding_dim=embedding_dim,
seq_pool=True,
dropout_rate=0.,
attention_dropout=0.1,
stochastic_depth=0.1,
*args, **kwargs)
def forward(self, x):
x = self.tokenizer(x)
return self.classifier(x)

vit_pytorch/distill.py

@@ -148,6 +148,6 @@ class DistillWrapper(nn.Module):
else:
teacher_labels = teacher_logits.argmax(dim = -1)
distill_loss = F.cross_entropy(student_logits, teacher_labels)
distill_loss = F.cross_entropy(distill_logits, teacher_labels)
return loss * (1 - alpha) + distill_loss * alpha
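This is the hard-distillation path fixed by commit 918869571c. A minimal sketch of how `DistillWrapper` exercises it with `hard = True`, assuming the `DistillableViT`/`DistillWrapper` interface from the README:
```python
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,
    alpha = 0.5,
    hard = True    # hard distillation: the fixed branch above, using the teacher's argmax labels
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()
```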

vit_pytorch/levit.py

@@ -29,7 +29,7 @@ class FeedForward(nn.Module):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Hardswish(),
nn.Dropout(dropout),
nn.Conv2d(dim * mult, dim, 1),
nn.Dropout(dropout)
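The Hardswish swap above lands in the LeViT feedforward. For reference, a usage sketch, assuming the `LeViT` constructor documented in the project README:
```python
import torch
from vit_pytorch.levit import LeViT

levit = LeViT(
    image_size = 224,
    num_classes = 1000,
    stages = 3,             # number of stages
    dim = (256, 384, 512),  # dimensions at each stage
    depth = 4,              # transformer depth at each stage
    heads = (4, 6, 8),      # heads at each stage
    mlp_mult = 2,           # feedforward expansion factor (the Hardswish block above)
    dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)
preds = levit(img) # (1, 1000)
```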

vit_pytorch/mae.py (new file, 93 lines)

@@ -0,0 +1,93 @@
import torch
from math import ceil
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat
from vit_pytorch.vit import Transformer
class MAE(nn.Module):
def __init__(
self,
*,
encoder,
decoder_dim,
masking_ratio = 0.75,
decoder_depth = 1,
decoder_heads = 8,
decoder_dim_head = 64
):
super().__init__()
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
self.masking_ratio = masking_ratio
# extract some hyperparameters and functions from encoder (vision transformer to be trained)
self.encoder = encoder
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
# decoder parameters
self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
self.mask_token = nn.Parameter(torch.randn(decoder_dim))
self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)
def forward(self, img):
device = img.device
# get patches
patches = self.to_patch(img)
batch, num_patches, *_ = patches.shape
# patch to encoder tokens and add positions
tokens = self.patch_to_emb(patches)
tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]
# calculate the number of patches to be masked, and get random indices, dividing them into masked vs unmasked
num_masked = int(self.masking_ratio * num_patches)
rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]
# get the unmasked tokens to be encoded
batch_range = torch.arange(batch, device = device)[:, None]
tokens = tokens[batch_range, unmasked_indices]
# get the patches to be masked for the final reconstruction loss
masked_patches = patches[batch_range, masked_indices]
# attend with vision transformer
encoded_tokens = self.encoder.transformer(tokens)
# project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder
decoder_tokens = self.enc_to_dec(encoded_tokens)
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
# concat the masked tokens to the decoder tokens and attend with decoder
decoder_tokens = torch.cat((decoder_tokens, mask_tokens), dim = 1)
decoded_tokens = self.decoder(decoder_tokens)
# splice out the mask tokens and project to pixel values
mask_tokens = decoded_tokens[:, -num_masked:]
pred_pixel_values = self.to_pixels(mask_tokens)
# calculate reconstruction loss
recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
return recon_loss

vit_pytorch/mpp.py

@@ -1,20 +1,20 @@
import math
from functools import reduce
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops import rearrange, repeat, reduce
# helpers
def exists(val):
return val is not None
def prob_mask_like(t, prob):
batch, seq_length, _ = t.shape
return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob
def get_mask_subset_with_prob(patched_input, prob):
batch, seq_len, _, device = *patched_input.shape, patched_input.device
max_masked = math.ceil(prob * seq_len)
@@ -31,43 +31,45 @@ def get_mask_subset_with_prob(patched_input, prob):
class MPPLoss(nn.Module):
def __init__(self, patch_size, channels, output_channel_bits,
max_pixel_val):
super(MPPLoss, self).__init__()
def __init__(
self,
patch_size,
channels,
output_channel_bits,
max_pixel_val,
mean,
std
):
super().__init__()
self.patch_size = patch_size
self.channels = channels
self.output_channel_bits = output_channel_bits
self.max_pixel_val = max_pixel_val
self.mean = torch.tensor(mean).view(-1, 1, 1) if mean else None
self.std = torch.tensor(std).view(-1, 1, 1) if std else None
def forward(self, predicted_patches, target, mask):
p, c, mpv, bits, device = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits, target.device
bin_size = mpv / (2 ** bits)
# un-normalize input
if exists(self.mean) and exists(self.std):
target = target * self.std + self.mean
# reshape target to patches
p = self.patch_size
target = rearrange(target,
"b c (h p1) (w p2) -> b (h w) c (p1 p2) ",
p1=p,
p2=p)
target = target.clamp(max = mpv) # clamp just in case
avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1 = p, p2 = p).contiguous()
avg_target = target.mean(dim=3)
bin_size = self.max_pixel_val / self.output_channel_bits
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size).to(avg_target.device)
channel_bins = torch.arange(bin_size, mpv, bin_size, device = device)
discretized_target = torch.bucketize(avg_target, channel_bins)
discretized_target = F.one_hot(discretized_target,
self.output_channel_bits)
c, bi = self.channels, self.output_channel_bits
discretized_target = rearrange(discretized_target,
"b n c bi -> b n (c bi)",
c=c,
bi=bi)
bin_mask = 2**torch.arange(c * bi - 1, -1,
-1).to(discretized_target.device,
discretized_target.dtype)
target_label = torch.sum(bin_mask * discretized_target, -1)
bin_mask = (2 ** bits) ** torch.arange(0, c, device = device).long()
bin_mask = rearrange(bin_mask, 'c -> () () c')
predicted_patches = predicted_patches[mask]
target_label = target_label[mask]
loss = F.cross_entropy(predicted_patches, target_label)
target_label = torch.sum(bin_mask * discretized_target, dim = -1)
loss = F.cross_entropy(predicted_patches[mask], target_label[mask])
return loss
@@ -75,20 +77,24 @@ class MPPLoss(nn.Module):
class MPP(nn.Module):
def __init__(self,
transformer,
patch_size,
dim,
output_channel_bits=3,
channels=3,
max_pixel_val=1.0,
mask_prob=0.15,
replace_prob=0.5,
random_patch_prob=0.5):
def __init__(
self,
transformer,
patch_size,
dim,
output_channel_bits=3,
channels=3,
max_pixel_val=1.0,
mask_prob=0.15,
replace_prob=0.5,
random_patch_prob=0.5,
mean=None,
std=None
):
super().__init__()
self.transformer = transformer
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
max_pixel_val)
max_pixel_val, mean, std)
# output transformation
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
@@ -102,7 +108,7 @@ class MPP(nn.Module):
self.random_patch_prob = random_patch_prob
# token ids
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))
def forward(self, input, **kwargs):
transformer = self.transformer

vit_pytorch/nest.py (new file, 179 lines)

@@ -0,0 +1,179 @@
from functools import partial
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
# helpers
def cast_tuple(val, depth):
return val if isinstance(val, tuple) else ((val,) * depth)
# classes
class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mlp_mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mlp_mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(dim * mlp_mult, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dropout = 0.):
super().__init__()
dim_head = dim // heads
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
b, c, h, w, heads = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
return self.to_out(out)
def Aggregate(dim, dim_out):
return nn.Sequential(
nn.Conv2d(dim, dim_out, 3, padding = 1),
LayerNorm(dim_out),
nn.MaxPool2d(3, stride = 2, padding = 1)
)
class Transformer(nn.Module):
def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
self.pos_emb = nn.Parameter(torch.randn(seq_len))
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
]))
def forward(self, x):
*_, h, w = x.shape
pos_emb = self.pos_emb[:(h * w)]
pos_emb = rearrange(pos_emb, '(h w) -> () () h w', h = h, w = w)
x = x + pos_emb
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class NesT(nn.Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
heads,
num_hierarchies,
block_repeats,
mlp_mult = 4,
channels = 3,
dim_head = 64,
dropout = 0.
):
super().__init__()
assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
fmap_size = image_size // patch_size
blocks = 2 ** (num_hierarchies - 1)
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across hierarchy
hierarchies = list(reversed(range(num_hierarchies)))
mults = [2 ** i for i in hierarchies]
layer_heads = list(map(lambda t: t * heads, mults))
layer_dims = list(map(lambda t: t * dim, mults))
layer_dims = [*layer_dims, layer_dims[-1]]
dim_pairs = zip(layer_dims[:-1], layer_dims[1:])
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
nn.Conv2d(patch_dim, layer_dims[0], 1),
)
block_repeats = cast_tuple(block_repeats, num_hierarchies)
self.layers = nn.ModuleList([])
for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats):
is_last = level == 0
depth = block_repeat
self.layers.append(nn.ModuleList([
Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout),
Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
]))
self.mlp_head = nn.Sequential(
LayerNorm(dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, c, h, w = x.shape
num_hierarchies = len(self.layers)
for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
block_size = 2 ** level
x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
x = transformer(x)
x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
x = aggregate(x)
return self.mlp_head(x)

vit_pytorch/pit.py

@@ -175,7 +175,7 @@ class PiT(nn.Module):
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :n+1]
x = self.dropout(x)
x = self.layers(x)
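The slice above is what lets PiT accept inputs smaller than `image_size` (#145/#146). A hedged sketch, assuming the `PiT` constructor arguments from the README:
```python
import torch
from vit_pytorch.pit import PiT

v = PiT(
    image_size = 224,
    patch_size = 14,
    dim = 256,
    num_classes = 1000,
    depth = (3, 3, 3),   # number of blocks per stage, before each downsample
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# after this change, an input smaller than image_size also works,
# since only the first n + 1 positional embeddings are added
img = torch.randn(1, 3, 160, 160)
preds = v(img) # (1, 1000)
```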

vit_pytorch/recorder.py

@@ -8,7 +8,7 @@ def find_modules(nn_module, type):
return [module for module in nn_module.modules() if isinstance(module, type)]
class Recorder(nn.Module):
def __init__(self, vit):
def __init__(self, vit, device = None):
super().__init__()
self.vit = vit
@@ -17,6 +17,7 @@ class Recorder(nn.Module):
self.hooks = []
self.hook_registered = False
self.ejected = False
self.device = device
def _hook(self, _, input, output):
self.recordings.append(output.clone().detach())
@@ -45,10 +46,14 @@ class Recorder(nn.Module):
def forward(self, img):
assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
self.clear()
if not self.hook_registered:
self._register_hook()
pred = self.vit(img)
attns = torch.stack(self.recordings, dim = 1)
# move all recordings to one device before stacking
target_device = self.device if self.device is not None else img.device
recordings = tuple(map(lambda t: t.to(target_device), self.recordings))
attns = torch.stack(recordings, dim = 1)
return pred, attns
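With the new `device` argument, all recorded attention maps can be forced onto a single device before stacking (the data-parallel fix from commit 22da26fa4b). A minimal sketch, assuming the `Recorder` usage shown in the README:
```python
import torch
from vit_pytorch import ViT
from vit_pytorch.recorder import Recorder

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# gather every hooked attention map on the CPU before they are stacked
v = Recorder(v, device = torch.device('cpu'))

img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)   # attns: (1, depth, heads, num patches + 1, num patches + 1)
```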

vit_pytorch/regionvit.py (new file, 267 lines)

@@ -0,0 +1,267 @@
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
import torch.nn.functional as F
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def divisible_by(val, d):
return (val % d) == 0
# helper classes
class Downsample(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)
def forward(self, x):
return self.conv(x)
class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
def forward(self, x):
return self.proj(x) + x
# transformer classes
def FeedForward(dim, mult = 4, dropout = 0.):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim, 1)
)
class Attention(nn.Module):
def __init__(
self,
dim,
heads = 4,
dim_head = 32,
dropout = 0.
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
def forward(self, x, rel_pos_bias = None):
h = self.heads
# prenorm
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
# split heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
# add relative positional bias for local tokens
if exists(rel_pos_bias):
sim = sim + rel_pos_bias
attn = sim.softmax(dim = -1)
# merge heads
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class R2LTransformer(nn.Module):
def __init__(
self,
dim,
*,
window_size,
depth = 4,
heads = 4,
dim_head = 32,
attn_dropout = 0.,
ff_dropout = 0.,
):
super().__init__()
self.layers = nn.ModuleList([])
self.window_size = window_size
rel_positions = 2 * window_size - 1
self.local_rel_pos_bias = nn.Embedding(rel_positions ** 2, heads)
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
FeedForward(dim, dropout = ff_dropout)
]))
def forward(self, local_tokens, region_tokens):
device = local_tokens.device
lh, lw = local_tokens.shape[-2:]
rh, rw = region_tokens.shape[-2:]
window_size_h, window_size_w = lh // rh, lw // rw
local_tokens = rearrange(local_tokens, 'b c h w -> b (h w) c')
region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')
# calculate local relative positional bias
h_range = torch.arange(window_size_h, device = device)
w_range = torch.arange(window_size_w, device = device)
grid_x, grid_y = torch.meshgrid(h_range, w_range)
grid = torch.stack((grid_x, grid_y))
grid = rearrange(grid, 'c h w -> c (h w)')
grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)
bias_indices = (grid * torch.tensor([1, self.window_size * 2 - 1], device = device)[:, None, None]).sum(dim = 0)
rel_pos_bias = self.local_rel_pos_bias(bias_indices)
rel_pos_bias = rearrange(rel_pos_bias, 'i j h -> () h i j')
rel_pos_bias = F.pad(rel_pos_bias, (1, 0, 1, 0), value = 0)
# go through r2l transformer layers
for attn, ff in self.layers:
region_tokens = attn(region_tokens) + region_tokens
# concat region tokens to local tokens
local_tokens = rearrange(local_tokens, 'b (h w) d -> b h w d', h = lh)
local_tokens = rearrange(local_tokens, 'b (h p1) (w p2) d -> (b h w) (p1 p2) d', p1 = window_size_h, p2 = window_size_w)
region_tokens = rearrange(region_tokens, 'b n d -> (b n) () d')
# do self attention on local tokens, along with its regional token
region_and_local_tokens = torch.cat((region_tokens, local_tokens), dim = 1)
region_and_local_tokens = attn(region_and_local_tokens, rel_pos_bias = rel_pos_bias) + region_and_local_tokens
# feedforward
region_and_local_tokens = ff(region_and_local_tokens) + region_and_local_tokens
# split back local and regional tokens
region_tokens, local_tokens = region_and_local_tokens[:, :1], region_and_local_tokens[:, 1:]
local_tokens = rearrange(local_tokens, '(b h w) (p1 p2) d -> b (h p1 w p2) d', h = lh // window_size_h, w = lw // window_size_w, p1 = window_size_h)
region_tokens = rearrange(region_tokens, '(b n) () d -> b n d', n = rh * rw)
local_tokens = rearrange(local_tokens, 'b (h w) c -> b c h w', h = lh, w = lw)
region_tokens = rearrange(region_tokens, 'b (h w) c -> b c h w', h = rh, w = rw)
return local_tokens, region_tokens
# classes
class RegionViT(nn.Module):
def __init__(
self,
*,
dim = (64, 128, 256, 512),
depth = (2, 2, 8, 2),
window_size = 7,
num_classes = 1000,
tokenize_local_3_conv = False,
local_patch_size = 4,
use_peg = False,
attn_dropout = 0.,
ff_dropout = 0.,
channels = 3,
):
super().__init__()
dim = cast_tuple(dim, 4)
depth = cast_tuple(depth, 4)
assert len(dim) == 4, 'dim needs to be a single value or a tuple of length 4'
assert len(depth) == 4, 'depth needs to be a single value or a tuple of length 4'
self.local_patch_size = local_patch_size
region_patch_size = local_patch_size * window_size
self.region_patch_size = local_patch_size * window_size
init_dim, *_, last_dim = dim
# local and region encoders
if tokenize_local_3_conv:
self.local_encoder = nn.Sequential(
nn.Conv2d(3, init_dim, 3, 2, 1),
nn.LayerNorm(init_dim),
nn.GELU(),
nn.Conv2d(init_dim, init_dim, 3, 2, 1),
nn.LayerNorm(init_dim),
nn.GELU(),
nn.Conv2d(init_dim, init_dim, 3, 1, 1)
)
else:
self.local_encoder = nn.Conv2d(3, init_dim, 8, 4, 3)
self.region_encoder = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = region_patch_size, p2 = region_patch_size),
nn.Conv2d((region_patch_size ** 2) * channels, init_dim, 1)
)
# layers
current_dim = init_dim
self.layers = nn.ModuleList([])
for ind, dim, num_layers in zip(range(4), dim, depth):
not_first = ind != 0
need_downsample = not_first
need_peg = not_first and use_peg
self.layers.append(nn.ModuleList([
Downsample(current_dim, dim) if need_downsample else nn.Identity(),
PEG(dim) if need_peg else nn.Identity(),
R2LTransformer(dim, depth = num_layers, window_size = window_size, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
]))
current_dim = dim
# final logits
self.to_logits = nn.Sequential(
Reduce('b c h w -> b c', 'mean'),
nn.LayerNorm(last_dim),
nn.Linear(last_dim, num_classes)
)
def forward(
self,
x,
return_local_tokens = False
):
*_, h, w = x.shape
assert divisible_by(h, self.region_patch_size) and divisible_by(w, self.region_patch_size), 'height and width must be divisible by region patch size'
assert divisible_by(h, self.local_patch_size) and divisible_by(w, self.local_patch_size), 'height and width must be divisible by local patch size'
local_tokens = self.local_encoder(x)
region_tokens = self.region_encoder(x)
for down, peg, transformer in self.layers:
local_tokens, region_tokens = down(local_tokens), down(region_tokens)
local_tokens = peg(local_tokens)
local_tokens, region_tokens = transformer(local_tokens, region_tokens)
return self.to_logits(region_tokens)

vit_pytorch/rvt.py

@@ -19,7 +19,7 @@ class AxialRotaryEmbedding(nn.Module):
def __init__(self, dim, max_freq = 10):
super().__init__()
self.dim = dim
scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
self.register_buffer('scales', scales)
def forward(self, x):
@@ -154,10 +154,10 @@ class Attention(nn.Module):
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, image_size, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
super().__init__()
self.layers = nn.ModuleList([])
self.pos_emb = AxialRotaryEmbedding(dim_head)
self.pos_emb = AxialRotaryEmbedding(dim_head, max_freq = image_size)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
@@ -187,7 +187,7 @@ class RvT(nn.Module):
)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary, use_ds_conv, use_glu)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, image_size, dropout, use_rotary, use_ds_conv, use_glu)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),

vit_pytorch/t2t.py

@@ -35,13 +35,14 @@ class T2TViT(nn.Module):
for i, (kernel_size, stride) in enumerate(t2t_layers):
layer_dim *= kernel_size ** 2
is_first = i == 0
is_last = i == (len(t2t_layers) - 1)
output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)
layers.extend([
RearrangeImage() if not is_first else nn.Identity(),
nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
Rearrange('b c n -> b n c'),
Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout),
Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout) if not is_last else nn.Identity(),
])
layers.append(nn.Linear(layer_dim, dim))
@@ -71,7 +72,7 @@ class T2TViT(nn.Module):
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :n+1]
x = self.dropout(x)
x = self.transformer(x)
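With the positional-embedding slice above, T2T-ViT also accepts inputs smaller than `image_size` (#147). A hedged sketch, assuming the `T2TViT` constructor from the README:
```python
import torch
from vit_pytorch.t2t import T2TViT

v = T2TViT(
    dim = 512,
    image_size = 224,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    num_classes = 1000,
    t2t_layers = ((7, 4), (3, 2), (3, 2))  # kernel size and stride of each token-to-token stage
)

preds = v(torch.randn(1, 3, 224, 224))  # (1, 1000)
preds = v(torch.randn(1, 3, 160, 160))  # inputs smaller than image_size also work after this change
```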

vit_pytorch/vit.py

@@ -1,6 +1,5 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
@@ -51,15 +50,14 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
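The einsum-to-matmul change above (#151) is motivated by ONNX/TensorRT export. A minimal, assumed sketch of the export step; the opset choice and whether a given einops version traces cleanly are not covered by this diff:
```python
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
).eval()

img = torch.randn(1, 3, 256, 256)

# trace to ONNX; the matmul-based attention exports as plain MatMul nodes,
# which a TensorRT conversion can then consume
torch.onnx.export(
    v, img, 'vit.onnx',
    opset_version = 13,
    input_names = ['image'],
    output_names = ['logits']
)
```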