Compare commits

...

27 Commits

Author SHA1 Message Date
Phil Wang
ca5e30bdf8 add SimMIM 2021-11-21 15:45:51 -08:00
Phil Wang
c5a461661c Merge pull request #170 from ankandrew/patch-1
add Table of Contents
2021-11-17 16:55:09 -08:00
ankandrew
e212918e2d add Table of Contents 2021-11-17 21:21:19 -03:00
Phil Wang
dc57c75478 cleanup 2021-11-14 12:24:48 -08:00
Phil Wang
99c44cf5f6 readme 2021-11-14 11:49:12 -08:00
Phil Wang
5b16e8f809 readme 2021-11-12 20:19:38 -08:00
Phil Wang
e8f6d72033 release masked autoencoder 2021-11-12 20:08:48 -08:00
Phil Wang
cb1729af28 more efficient feedforward for regionvit 2021-11-07 17:18:59 -08:00
Phil Wang
9e50b2a41e readme 2021-11-07 09:59:49 -08:00
Phil Wang
06d375351e add RegionViT paper 2021-11-07 09:47:28 -08:00
Phil Wang
f196d1ec5b move freqs in RvT to linspace 2021-10-05 09:23:44 -07:00
Phil Wang
529044c9b3 Merge pull request #153 from developer0hye/fix-example
fix transforms for val an test process in example code
2021-09-02 06:57:16 -07:00
yhkwon-DT01
c30655f3bc fix transforms for val an test process 2021-09-02 17:30:18 +09:00
Phil Wang
d2d6de01d3 0.20.7 2021-08-30 08:14:43 -07:00
Phil Wang
b9eadaef60 Merge pull request #151 from developer0hye/patch-1
Cleanup Attention Class & matmul based implementation for TensorRT conversion
2021-08-30 08:14:11 -07:00
Yonghye Kwon
24ac8350bf remove unused package 2021-08-30 18:25:03 +09:00
Yonghye Kwon
ca3cef9de0 Cleanup Attention Class 2021-08-30 18:05:16 +09:00
Phil Wang
6e1be11517 0.20.6 2021-08-21 09:03:54 -07:00
Phil Wang
73ed562ce4 Merge pull request #147 from developer0hye/patch-4
Make T2T process any scale image
2021-08-21 09:03:42 -07:00
Phil Wang
ff863175a6 Merge pull request #146 from developer0hye/patch-1
Make Pit process image with width and height less than the image_size
2021-08-21 09:03:31 -07:00
Yonghye Kwon
ca0bdca192 Make model process any scale image
Related to #145
2021-08-21 22:35:26 +09:00
Yonghye Kwon
1c70271778 Support image with width and height less than the image_size
Related to #145
2021-08-21 22:25:46 +09:00
Phil Wang
d7d3febfe3 Merge pull request #144 from developer0hye/patch-1
Remove unused package
2021-08-20 10:14:02 -07:00
Yonghye Kwon
946815164a Remove unused package 2021-08-20 13:44:57 +09:00
Phil Wang
aeed3381c1 use hardswish for levit 2021-08-19 08:22:55 -07:00
Phil Wang
3f754956fb remove last transformer layer in t2t 2021-08-14 08:06:23 -07:00
Phil Wang
918869571c fix hard distillation, thanks to @CiaoHe 2021-08-12 08:40:57 -07:00
17 changed files with 653 additions and 34 deletions

197
README.md
View File

@@ -1,5 +1,35 @@
<img src="./images/vit.gif" width="500px"></img>
## Table of Contents
- [Vision Transformer - Pytorch](#vision-transformer---pytorch)
- [Install](#install)
- [Usage](#usage)
- [Parameters](#parameters)
- [Distillation](#distillation)
- [Deep ViT](#deep-vit)
- [CaiT](#cait)
- [Token-to-Token ViT](#token-to-token-vit)
- [CCT](#cct)
- [Cross ViT](#cross-vit)
- [PiT](#pit)
- [LeViT](#levit)
- [CvT](#cvt)
- [Twins SVT](#twins-svt)
- [RegionViT](#regionvit)
- [NesT](#nest)
- [Masked Autoencoder](#masked-autoencoder)
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
- [Dino](#dino)
- [Accessing Attention](#accessing-attention)
- [Research Ideas](#research-ideas)
* [Efficient Attention](#efficient-attention)
* [Combining with other Transformer improvements](#combining-with-other-transformer-improvements)
- [FAQ](#faq)
- [Resources](#resources)
- [Citations](#citations)
## Vision Transformer - Pytorch
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
@@ -435,6 +465,34 @@ img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```
## RegionViT
<img src="./images/regionvit.png" width="400px"></img>
<img src="./images/regionvit2.png" width="400px"></img>
<a href="https://arxiv.org/abs/2106.02689">This paper</a> proposes to divide up the feature map into local regions, whereby the local tokens attend to each other. Each local region has its own regional token which then attends to all its local tokens, as well as other regional tokens.
You can use it as follows
```python
import torch
from vit_pytorch.regionvit import RegionViT
model = RegionViT(
dim = (64, 128, 256, 512), # tuple of size 4, indicating dimension at each stage
depth = (2, 2, 8, 2), # depth of the region to local transformer at each stage
window_size = 7, # window size, which should be either 7 or 14
num_classes = 1000, # number of output classes
tokenize_local_3_conv = False, # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
use_peg = False, # whether to use positional generating module. they used this for object detection for a boost in performance
)
img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```
## NesT
<img src="./images/nest.png" width="400px"></img>
@@ -458,9 +516,93 @@ nest = NesT(
)
img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```
## Simple Masked Image Modeling
<img src="./images/simmim.png" width="400px"/>
This <a href="https://arxiv.org/abs/2111.09886">paper</a> proposes a simple masked image modeling (SimMIM) scheme, using only a linear projection off the masked tokens into pixel space followed by an L1 loss with the pixel values of the masked patches. Results are competitive with other more complicated approaches.
You can use this as follows
```python
import torch
from vit_pytorch import ViT
from vit_pytorch.simmim import SimMIM
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
)
mim = SimMIM(
encoder = v,
masking_ratio = 0.5 # they found 50% to yield the best results
)
images = torch.randn(8, 3, 256, 256)
loss = mim(images)
loss.backward()
# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn
torch.save(v.state_dict(), './trained-vit.pt')
```
## Masked Autoencoder
<img src="./images/mae.png" width="400px"/>
A new <a href="https://arxiv.org/abs/2111.06377">Kaiming He paper</a> proposes a simple autoencoder scheme where the vision transformer attends to a set of unmasked patches, and a smaller decoder tries to reconstruct the masked pixel values.
<a href="https://www.youtube.com/watch?v=LKixq2S2Pz8">DeepReader quick paper review</a>
You can use it with the following code
```python
import torch
from vit_pytorch import ViT, MAE
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
)
mae = MAE(
encoder = v,
masking_ratio = 0.75, # the paper recommended 75% masked patches
decoder_dim = 512, # paper showed good results with just 512
decoder_depth = 6 # anywhere from 1 to 8
)
images = torch.randn(8, 3, 256, 256)
loss = mae(images)
loss.backward()
# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn
# save your improved vision transformer
torch.save(v.state_dict(), './trained-vit.pt')
```
## Masked Patch Prediction
Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
@@ -739,13 +881,13 @@ Coming from computer vision and new to transformers? Here are some resources tha
## Citations
```bibtex
@article{hassani2021escaping,
title = {Escaping the Big Data Paradigm with Compact Transformers},
author = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
year = 2021,
url = {https://arxiv.org/abs/2104.05704},
eprint = {2104.05704},
archiveprefix = {arXiv},
primaryclass = {cs.CV}
title = {Escaping the Big Data Paradigm with Compact Transformers},
author = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
year = 2021,
url = {https://arxiv.org/abs/2104.05704},
eprint = {2104.05704},
archiveprefix = {arXiv},
primaryclass = {cs.CV}
}
```
@@ -773,10 +915,10 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{yuan2021tokenstotoken,
title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
year = {2021},
eprint = {2101.11986},
title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
year = {2021},
eprint = {2101.11986},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@@ -892,6 +1034,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{chen2021regionvit,
title = {RegionViT: Regional-to-Local Attention for Vision Transformers},
author = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan},
year = {2021},
eprint = {2106.02689},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{caron2021emerging,
title = {Emerging Properties in Self-Supervised Vision Transformers},
@@ -903,6 +1056,28 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{he2021masked,
title = {Masked Autoencoders Are Scalable Vision Learners},
author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
year = {2021},
eprint = {2111.06377},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{xie2021simmim,
title = {SimMIM: A Simple Framework for Masked Image Modeling},
author = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
year = {2021},
eprint = {2111.09886},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},

View File

@@ -364,9 +364,8 @@
"\n",
"val_transforms = transforms.Compose(\n",
" [\n",
" transforms.Resize((224, 224)),\n",
" transforms.RandomResizedCrop(224),\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.Resize(256),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" ]\n",
")\n",
@@ -374,9 +373,8 @@
"\n",
"test_transforms = transforms.Compose(\n",
" [\n",
" transforms.Resize((224, 224)),\n",
" transforms.RandomResizedCrop(224),\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.Resize(256),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" ]\n",
")\n"
@@ -6250,4 +6248,4 @@
},
"nbformat": 4,
"nbformat_minor": 1
}
}

BIN
images/mae.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 198 KiB

BIN
images/regionvit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 94 KiB

BIN
images/regionvit2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 55 KiB

BIN
images/simmim.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 365 KiB

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.20.2',
version = '0.23.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -1,2 +1,3 @@
from vit_pytorch.vit import ViT
from vit_pytorch.mae import MAE
from vit_pytorch.dino import Dino

View File

@@ -148,6 +148,6 @@ class DistillWrapper(nn.Module):
else:
teacher_labels = teacher_logits.argmax(dim = -1)
distill_loss = F.cross_entropy(student_logits, teacher_labels)
distill_loss = F.cross_entropy(distill_logits, teacher_labels)
return loss * (1 - alpha) + distill_loss * alpha

View File

@@ -29,7 +29,7 @@ class FeedForward(nn.Module):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Hardswish(),
nn.Dropout(dropout),
nn.Conv2d(dim * mult, dim, 1),
nn.Dropout(dropout)

92
vit_pytorch/mae.py Normal file
View File

@@ -0,0 +1,92 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from vit_pytorch.vit import Transformer
class MAE(nn.Module):
def __init__(
self,
*,
encoder,
decoder_dim,
masking_ratio = 0.75,
decoder_depth = 1,
decoder_heads = 8,
decoder_dim_head = 64
):
super().__init__()
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
self.masking_ratio = masking_ratio
# extract some hyperparameters and functions from encoder (vision transformer to be trained)
self.encoder = encoder
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
# decoder parameters
self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
self.mask_token = nn.Parameter(torch.randn(decoder_dim))
self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)
def forward(self, img):
device = img.device
# get patches
patches = self.to_patch(img)
batch, num_patches, *_ = patches.shape
# patch to encoder tokens and add positions
tokens = self.patch_to_emb(patches)
tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]
# calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked
num_masked = int(self.masking_ratio * num_patches)
rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]
# get the unmasked tokens to be encoded
batch_range = torch.arange(batch, device = device)[:, None]
tokens = tokens[batch_range, unmasked_indices]
# get the patches to be masked for the final reconstruction loss
masked_patches = patches[batch_range, masked_indices]
# attend with vision transformer
encoded_tokens = self.encoder.transformer(tokens)
# project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder
decoder_tokens = self.enc_to_dec(encoded_tokens)
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
# concat the masked tokens to the decoder tokens and attend with decoder
decoder_tokens = torch.cat((decoder_tokens, mask_tokens), dim = 1)
decoded_tokens = self.decoder(decoder_tokens)
# splice out the mask tokens and project to pixel values
mask_tokens = decoded_tokens[:, -num_masked:]
pred_pixel_values = self.to_pixels(mask_tokens)
# calculate reconstruction loss
recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
return recon_loss

View File

@@ -175,7 +175,7 @@ class PiT(nn.Module):
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :n+1]
x = self.dropout(x)
x = self.layers(x)

267
vit_pytorch/regionvit.py Normal file
View File

@@ -0,0 +1,267 @@
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
import torch.nn.functional as F
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def divisible_by(val, d):
return (val % d) == 0
# helper classes
class Downsample(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)
def forward(self, x):
return self.conv(x)
class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
def forward(self, x):
return self.proj(x) + x
# transformer classes
def FeedForward(dim, mult = 4, dropout = 0.):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim, 1)
)
class Attention(nn.Module):
def __init__(
self,
dim,
heads = 4,
dim_head = 32,
dropout = 0.
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
def forward(self, x, rel_pos_bias = None):
h = self.heads
# prenorm
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
# split heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
# add relative positional bias for local tokens
if exists(rel_pos_bias):
sim = sim + rel_pos_bias
attn = sim.softmax(dim = -1)
# merge heads
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class R2LTransformer(nn.Module):
def __init__(
self,
dim,
*,
window_size,
depth = 4,
heads = 4,
dim_head = 32,
attn_dropout = 0.,
ff_dropout = 0.,
):
super().__init__()
self.layers = nn.ModuleList([])
self.window_size = window_size
rel_positions = 2 * window_size - 1
self.local_rel_pos_bias = nn.Embedding(rel_positions ** 2, heads)
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
FeedForward(dim, dropout = ff_dropout)
]))
def forward(self, local_tokens, region_tokens):
device = local_tokens.device
lh, lw = local_tokens.shape[-2:]
rh, rw = region_tokens.shape[-2:]
window_size_h, window_size_w = lh // rh, lw // rw
local_tokens = rearrange(local_tokens, 'b c h w -> b (h w) c')
region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')
# calculate local relative positional bias
h_range = torch.arange(window_size_h, device = device)
w_range = torch.arange(window_size_w, device = device)
grid_x, grid_y = torch.meshgrid(h_range, w_range)
grid = torch.stack((grid_x, grid_y))
grid = rearrange(grid, 'c h w -> c (h w)')
grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)
bias_indices = (grid * torch.tensor([1, self.window_size * 2 - 1], device = device)[:, None, None]).sum(dim = 0)
rel_pos_bias = self.local_rel_pos_bias(bias_indices)
rel_pos_bias = rearrange(rel_pos_bias, 'i j h -> () h i j')
rel_pos_bias = F.pad(rel_pos_bias, (1, 0, 1, 0), value = 0)
# go through r2l transformer layers
for attn, ff in self.layers:
region_tokens = attn(region_tokens) + region_tokens
# concat region tokens to local tokens
local_tokens = rearrange(local_tokens, 'b (h w) d -> b h w d', h = lh)
local_tokens = rearrange(local_tokens, 'b (h p1) (w p2) d -> (b h w) (p1 p2) d', p1 = window_size_h, p2 = window_size_w)
region_tokens = rearrange(region_tokens, 'b n d -> (b n) () d')
# do self attention on local tokens, along with its regional token
region_and_local_tokens = torch.cat((region_tokens, local_tokens), dim = 1)
region_and_local_tokens = attn(region_and_local_tokens, rel_pos_bias = rel_pos_bias) + region_and_local_tokens
# feedforward
region_and_local_tokens = ff(region_and_local_tokens) + region_and_local_tokens
# split back local and regional tokens
region_tokens, local_tokens = region_and_local_tokens[:, :1], region_and_local_tokens[:, 1:]
local_tokens = rearrange(local_tokens, '(b h w) (p1 p2) d -> b (h p1 w p2) d', h = lh // window_size_h, w = lw // window_size_w, p1 = window_size_h)
region_tokens = rearrange(region_tokens, '(b n) () d -> b n d', n = rh * rw)
local_tokens = rearrange(local_tokens, 'b (h w) c -> b c h w', h = lh, w = lw)
region_tokens = rearrange(region_tokens, 'b (h w) c -> b c h w', h = rh, w = rw)
return local_tokens, region_tokens
# classes
class RegionViT(nn.Module):
def __init__(
self,
*,
dim = (64, 128, 256, 512),
depth = (2, 2, 8, 2),
window_size = 7,
num_classes = 1000,
tokenize_local_3_conv = False,
local_patch_size = 4,
use_peg = False,
attn_dropout = 0.,
ff_dropout = 0.,
channels = 3,
):
super().__init__()
dim = cast_tuple(dim, 4)
depth = cast_tuple(depth, 4)
assert len(dim) == 4, 'dim needs to be a single value or a tuple of length 4'
assert len(depth) == 4, 'depth needs to be a single value or a tuple of length 4'
self.local_patch_size = local_patch_size
region_patch_size = local_patch_size * window_size
self.region_patch_size = local_patch_size * window_size
init_dim, *_, last_dim = dim
# local and region encoders
if tokenize_local_3_conv:
self.local_encoder = nn.Sequential(
nn.Conv2d(3, init_dim, 3, 2, 1),
nn.LayerNorm(init_dim),
nn.GELU(),
nn.Conv2d(init_dim, init_dim, 3, 2, 1),
nn.LayerNorm(init_dim),
nn.GELU(),
nn.Conv2d(init_dim, init_dim, 3, 1, 1)
)
else:
self.local_encoder = nn.Conv2d(3, init_dim, 8, 4, 3)
self.region_encoder = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = region_patch_size, p2 = region_patch_size),
nn.Conv2d((region_patch_size ** 2) * channels, init_dim, 1)
)
# layers
current_dim = init_dim
self.layers = nn.ModuleList([])
for ind, dim, num_layers in zip(range(4), dim, depth):
not_first = ind != 0
need_downsample = not_first
need_peg = not_first and use_peg
self.layers.append(nn.ModuleList([
Downsample(current_dim, dim) if need_downsample else nn.Identity(),
PEG(dim) if need_peg else nn.Identity(),
R2LTransformer(dim, depth = num_layers, window_size = window_size, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
]))
current_dim = dim
# final logits
self.to_logits = nn.Sequential(
Reduce('b c h w -> b c', 'mean'),
nn.LayerNorm(last_dim),
nn.Linear(last_dim, num_classes)
)
def forward(
self,
x,
return_local_tokens = False
):
*_, h, w = x.shape
assert divisible_by(h, self.region_patch_size) and divisible_by(w, self.region_patch_size), 'height and width must be divisible by region patch size'
assert divisible_by(h, self.local_patch_size) and divisible_by(w, self.local_patch_size), 'height and width must be divisible by local patch size'
local_tokens = self.local_encoder(x)
region_tokens = self.region_encoder(x)
for down, peg, transformer in self.layers:
local_tokens, region_tokens = down(local_tokens), down(region_tokens)
local_tokens = peg(local_tokens)
local_tokens, region_tokens = transformer(local_tokens, region_tokens)
return self.to_logits(region_tokens)

View File

@@ -19,7 +19,7 @@ class AxialRotaryEmbedding(nn.Module):
def __init__(self, dim, max_freq = 10):
super().__init__()
self.dim = dim
scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
self.register_buffer('scales', scales)
def forward(self, x):
@@ -154,10 +154,10 @@ class Attention(nn.Module):
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, image_size, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
super().__init__()
self.layers = nn.ModuleList([])
self.pos_emb = AxialRotaryEmbedding(dim_head)
self.pos_emb = AxialRotaryEmbedding(dim_head, max_freq = image_size)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
@@ -187,7 +187,7 @@ class RvT(nn.Module):
)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary, use_ds_conv, use_glu)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, image_size, dropout, use_rotary, use_ds_conv, use_glu)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),

87
vit_pytorch/simmim.py Normal file
View File

@@ -0,0 +1,87 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from vit_pytorch.vit import Transformer
class SimMIM(nn.Module):
def __init__(
self,
*,
encoder,
masking_ratio = 0.5
):
super().__init__()
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
self.masking_ratio = masking_ratio
# extract some hyperparameters and functions from encoder (vision transformer to be trained)
self.encoder = encoder
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
# simple linear head
self.mask_token = nn.Parameter(torch.randn(encoder_dim))
self.to_pixels = nn.Linear(encoder_dim, pixel_values_per_patch)
def forward(self, img):
device = img.device
# get patches
patches = self.to_patch(img)
batch, num_patches, *_ = patches.shape
# for indexing purposes
batch_range = torch.arange(batch, device = device)[:, None]
# get positions
pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)]
# patch to encoder tokens and add positions
tokens = self.patch_to_emb(patches)
tokens = tokens + pos_emb
# prepare mask tokens
mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_patches)
mask_tokens = mask_tokens + pos_emb
# calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked
num_masked = int(self.masking_ratio * num_patches)
rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
masked_indices = rand_indices[:, :num_masked]
masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool()
# mask tokens
tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)
# attend with vision transformer
encoded = self.encoder.transformer(tokens)
# get the masked tokens
encoded_mask_tokens = encoded[batch_range, masked_indices]
# small linear projection for predicted pixel values
pred_pixel_values = self.to_pixels(encoded_mask_tokens)
# get the masked patches for the final reconstruction loss
masked_patches = patches[batch_range, masked_indices]
# calculate reconstruction loss
recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked
return recon_loss

View File

@@ -35,13 +35,14 @@ class T2TViT(nn.Module):
for i, (kernel_size, stride) in enumerate(t2t_layers):
layer_dim *= kernel_size ** 2
is_first = i == 0
is_last = i == (len(t2t_layers) - 1)
output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)
layers.extend([
RearrangeImage() if not is_first else nn.Identity(),
nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
Rearrange('b c n -> b n c'),
Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout),
Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout) if not is_last else nn.Identity(),
])
layers.append(nn.Linear(layer_dim, dim))
@@ -71,7 +72,7 @@ class T2TViT(nn.Module):
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :n+1]
x = self.dropout(x)
x = self.transformer(x)

View File

@@ -1,6 +1,5 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
@@ -51,15 +50,14 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)