Mirror of https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 16:12:29 +00:00
Compare commits
24 Commits
| SHA1 |
|---|
| 5ae555750f |
| c5a461661c |
| e212918e2d |
| dc57c75478 |
| 99c44cf5f6 |
| 5b16e8f809 |
| e8f6d72033 |
| cb1729af28 |
| 9e50b2a41e |
| 06d375351e |
| f196d1ec5b |
| 529044c9b3 |
| c30655f3bc |
| d2d6de01d3 |
| b9eadaef60 |
| 24ac8350bf |
| ca3cef9de0 |
| 6e1be11517 |
| 73ed562ce4 |
| ff863175a6 |
| ca0bdca192 |
| 1c70271778 |
| d7d3febfe3 |
| 946815164a |
README.md (197 changed lines)
@@ -1,5 +1,35 @@
<img src="./images/vit.gif" width="500px"></img>

## Table of Contents

- [Vision Transformer - Pytorch](#vision-transformer---pytorch)
- [Install](#install)
- [Usage](#usage)
- [Parameters](#parameters)
- [Distillation](#distillation)
- [Deep ViT](#deep-vit)
- [CaiT](#cait)
- [Token-to-Token ViT](#token-to-token-vit)
- [CCT](#cct)
- [Cross ViT](#cross-vit)
- [PiT](#pit)
- [LeViT](#levit)
- [CvT](#cvt)
- [Twins SVT](#twins-svt)
- [RegionViT](#regionvit)
- [NesT](#nest)
- [Masked Autoencoder](#masked-autoencoder)
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
- [Dino](#dino)
- [Accessing Attention](#accessing-attention)
- [Research Ideas](#research-ideas)
  * [Efficient Attention](#efficient-attention)
  * [Combining with other Transformer improvements](#combining-with-other-transformer-improvements)
- [FAQ](#faq)
- [Resources](#resources)
- [Citations](#citations)

## Vision Transformer - Pytorch

Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
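As an aside for readers of this diff: the README's Usage section (not part of this hunk) boils down to constructing a `ViT` and calling it on an image batch. A minimal sketch follows; the hyperparameter values are illustrative only.

```python
import torch
from vit_pytorch import ViT

# minimal usage sketch; values chosen for illustration
v = ViT(
    image_size = 256,   # input images are 256 x 256
    patch_size = 32,    # (256 / 32) ** 2 = 64 patches
    num_classes = 1000,
    dim = 1024,         # token dimension
    depth = 6,          # number of transformer blocks
    heads = 16,
    mlp_dim = 2048
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)
```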
@@ -435,6 +465,34 @@ img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```

## RegionViT

<img src="./images/regionvit.png" width="400px"></img>

<img src="./images/regionvit2.png" width="400px"></img>

<a href="https://arxiv.org/abs/2106.02689">This paper</a> proposes to divide up the feature map into local regions, whereby the local tokens attend to each other. Each local region has its own regional token which then attends to all its local tokens, as well as other regional tokens.

You can use it as follows

```python
import torch
from vit_pytorch.regionvit import RegionViT

model = RegionViT(
    dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage
    depth = (2, 2, 8, 2),           # depth of the region to local transformer at each stage
    window_size = 7,                # window size, which should be either 7 or 14
    num_classes = 1000,             # number of output classes
    tokenize_local_3_conv = False,  # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
    use_peg = False,                # whether to use positional generating module. they used this for object detection for a boost in performance
)

img = torch.randn(1, 3, 224, 224)

pred = model(img) # (1, 1000)
```

## NesT

<img src="./images/nest.png" width="400px"></img>

@@ -458,9 +516,93 @@ nest = NesT(
)

img = torch.randn(1, 3, 224, 224)

pred = nest(img) # (1, 1000)
```

## Simple Masked Image Modeling

<img src="./images/simmim.png" width="400px"/>

This <a href="https://arxiv.org/abs/2111.09886">paper</a> proposes a simple masked image modeling (SimMIM) scheme, using only a linear projection off the masked tokens into pixel space followed by an L1 loss with the pixel values of the masked patches. Results are competitive with other more complicated approaches.

You can use this as follows

```python
import torch
from vit_pytorch import ViT
from vit_pytorch.simmim import SimMIM

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mim = SimMIM(
    encoder = v,
    masking_ratio = 0.5  # they found 50% to yield the best results
)

images = torch.randn(8, 3, 256, 256)

loss = mim(images)
loss.backward()

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn

torch.save(v.state_dict(), './trained-vit.pt')
```

## Masked Autoencoder

<img src="./images/mae.png" width="400px"/>

A new <a href="https://arxiv.org/abs/2111.06377">Kaiming He paper</a> proposes a simple autoencoder scheme where the vision transformer attends to a set of unmasked patches, and a smaller decoder tries to reconstruct the masked pixel values.

<a href="https://www.youtube.com/watch?v=LKixq2S2Pz8">DeepReader quick paper review</a>

You can use it with the following code

```python
import torch
from vit_pytorch import ViT, MAE

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mae = MAE(
    encoder = v,
    masking_ratio = 0.75,  # the paper recommended 75% masked patches
    decoder_dim = 512,     # paper showed good results with just 512
    decoder_depth = 6      # anywhere from 1 to 8
)

images = torch.randn(8, 3, 256, 256)

loss = mae(images)
loss.backward()

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn

# save your improved vision transformer
torch.save(v.state_dict(), './trained-vit.pt')
```

## Masked Patch Prediction

Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
@@ -739,13 +881,13 @@ Coming from computer vision and new to transformers? Here are some resources tha
## Citations

```bibtex
@article{hassani2021escaping,
    title = {Escaping the Big Data Paradigm with Compact Transformers},
    author = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
    year = 2021,
    url = {https://arxiv.org/abs/2104.05704},
    eprint = {2104.05704},
    archiveprefix = {arXiv},
    primaryclass = {cs.CV}
}
```

@@ -773,10 +915,10 @@ Coming from computer vision and new to transformers? Here are some resources tha

```bibtex
@misc{yuan2021tokenstotoken,
    title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
    author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
    year = {2021},
    eprint = {2101.11986},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@@ -892,6 +1034,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@misc{chen2021regionvit,
    title = {RegionViT: Regional-to-Local Attention for Vision Transformers},
    author = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan},
    year = {2021},
    eprint = {2106.02689},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

```bibtex
@misc{caron2021emerging,
    title = {Emerging Properties in Self-Supervised Vision Transformers},
@@ -903,6 +1056,28 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@misc{he2021masked,
    title = {Masked Autoencoders Are Scalable Vision Learners},
    author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
    year = {2021},
    eprint = {2111.06377},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

```bibtex
@misc{xie2021simmim,
    title = {SimMIM: A Simple Framework for Masked Image Modeling},
    author = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
    year = {2021},
    eprint = {2111.09886},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

```bibtex
@misc{vaswani2017attention,
    title = {Attention Is All You Need},
@@ -364,9 +364,8 @@
"\n",
"val_transforms = transforms.Compose(\n",
"    [\n",
"        transforms.Resize((224, 224)),\n",
"        transforms.RandomResizedCrop(224),\n",
"        transforms.RandomHorizontalFlip(),\n",
"        transforms.Resize(256),\n",
"        transforms.CenterCrop(224),\n",
"        transforms.ToTensor(),\n",
"    ]\n",
")\n",
@@ -374,9 +373,8 @@
"\n",
"test_transforms = transforms.Compose(\n",
"    [\n",
"        transforms.Resize((224, 224)),\n",
"        transforms.RandomResizedCrop(224),\n",
"        transforms.RandomHorizontalFlip(),\n",
"        transforms.Resize(256),\n",
"        transforms.CenterCrop(224),\n",
"        transforms.ToTensor(),\n",
"    ]\n",
")\n"
@@ -6250,4 +6248,4 @@
},
"nbformat": 4,
"nbformat_minor": 1
}
}
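The removed/added markers were lost in this mirror, but the hunk counts suggest the example notebook's val/test pipelines swap the `Resize((224, 224))` / `RandomResizedCrop` / `RandomHorizontalFlip` combination for a deterministic resize plus center crop. As a standalone sketch (assuming torchvision, names chosen here for illustration), an evaluation-time pipeline of that shape looks like:

```python
from torchvision import transforms

# deterministic preprocessing for validation / test:
# no random crop or flip, just resize, center crop and tensor conversion
eval_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
```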
images/mae.png (new binary file, not shown; 198 KiB)
images/regionvit.png (new binary file, not shown; 94 KiB)
images/regionvit2.png (new binary file, not shown; 55 KiB)
images/simmim.png (new binary file, not shown; 365 KiB)
setup.py (2 changed lines)
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'vit-pytorch',
  packages = find_packages(exclude=['examples']),
  version = '0.20.5',
  version = '0.23.2',
  license='MIT',
  description = 'Vision Transformer (ViT) - Pytorch',
  author = 'Phil Wang',
@@ -1,2 +1,3 @@
from vit_pytorch.vit import ViT
from vit_pytorch.mae import MAE
from vit_pytorch.dino import Dino
vit_pytorch/mae.py (new file, 92 lines)
@@ -0,0 +1,92 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

from vit_pytorch.vit import Transformer

class MAE(nn.Module):
    def __init__(
        self,
        *,
        encoder,
        decoder_dim,
        masking_ratio = 0.75,
        decoder_depth = 1,
        decoder_heads = 8,
        decoder_dim_head = 64
    ):
        super().__init__()
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # extract some hyperparameters and functions from encoder (vision transformer to be trained)

        self.encoder = encoder
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
        self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
        pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]

        # decoder parameters

        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
        self.mask_token = nn.Parameter(torch.randn(decoder_dim))
        self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
        self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
        self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)

    def forward(self, img):
        device = img.device

        # get patches

        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        # patch to encoder tokens and add positions

        tokens = self.patch_to_emb(patches)
        tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]

        # calculate the number of patches needed to be masked, and get random indices, dividing it up for masked vs unmasked

        num_masked = int(self.masking_ratio * num_patches)
        rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
        masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

        # get the unmasked tokens to be encoded

        batch_range = torch.arange(batch, device = device)[:, None]
        tokens = tokens[batch_range, unmasked_indices]

        # get the patches to be masked for the final reconstruction loss

        masked_patches = patches[batch_range, masked_indices]

        # attend with vision transformer

        encoded_tokens = self.encoder.transformer(tokens)

        # project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder

        decoder_tokens = self.enc_to_dec(encoded_tokens)

        # repeat mask tokens for number of masked, and add the positions using the masked indices derived above

        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
        mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)

        # concat the masked tokens to the decoder tokens and attend with decoder

        decoder_tokens = torch.cat((decoder_tokens, mask_tokens), dim = 1)
        decoded_tokens = self.decoder(decoder_tokens)

        # splice out the mask tokens and project to pixel values

        mask_tokens = decoded_tokens[:, -num_masked:]
        pred_pixel_values = self.to_pixels(mask_tokens)

        # calculate reconstruction loss

        recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
        return recon_loss
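A note on the masking in `MAE.forward` above: `torch.rand(...).argsort(dim = -1)` gives an independent random permutation of patch indices per sample, and the first `num_masked` entries of that permutation become the masked set. A small standalone sketch of just that indexing trick (shapes here are arbitrary, chosen for illustration):

```python
import torch

batch, num_patches, dim = 2, 8, 4
masking_ratio = 0.75

tokens = torch.randn(batch, num_patches, dim)

# a random permutation of patch indices, independently per sample
rand_indices = torch.rand(batch, num_patches).argsort(dim = -1)

num_masked = int(masking_ratio * num_patches)      # 6 of the 8 patches get masked
masked_indices = rand_indices[:, :num_masked]      # (batch, num_masked)
unmasked_indices = rand_indices[:, num_masked:]    # (batch, num_patches - num_masked)

# advanced indexing with a (batch, 1) range gathers each sample's own selection
batch_range = torch.arange(batch)[:, None]
unmasked_tokens = tokens[batch_range, unmasked_indices]
print(unmasked_tokens.shape)                       # torch.Size([2, 2, 4])
```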
@@ -175,7 +175,7 @@ class PiT(nn.Module):

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x += self.pos_embedding[:, :n+1]
        x = self.dropout(x)

        x = self.layers(x)
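The PiT hunk above (and the matching T2T hunk further down) slices the learned positional embedding to the current sequence length, cls token plus n patches, rather than adding the whole table. That matters when the input produces fewer patches than the embedding was sized for; a toy illustration of the broadcasting issue, with made-up shapes:

```python
import torch

max_patches, dim = 64, 128
pos_embedding = torch.randn(1, max_patches + 1, dim)  # sized for the maximum sequence (cls + 64 patches)

n = 16                                                # suppose a smaller input yields only 16 patches
x = torch.randn(2, n + 1, dim)                        # cls token + patches

# x + pos_embedding                 # would fail: (2, 17, dim) vs (1, 65, dim)
x = x + pos_embedding[:, :n + 1]    # slice to the actual sequence length first
print(x.shape)                      # torch.Size([2, 17, 128])
```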
vit_pytorch/regionvit.py (new file, 267 lines)
@@ -0,0 +1,267 @@
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
import torch.nn.functional as F

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

def divisible_by(val, d):
    return (val % d) == 0

# helper classes

class Downsample(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)

    def forward(self, x):
        return self.conv(x)

class PEG(nn.Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)

    def forward(self, x):
        return self.proj(x) + x

# transformer classes

def FeedForward(dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult, 1),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim, 1)
    )

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, rel_pos_bias = None):
        h = self.heads

        # prenorm

        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # split heads

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add relative positional bias for local tokens

        if exists(rel_pos_bias):
            sim = sim + rel_pos_bias

        attn = sim.softmax(dim = -1)

        # merge heads

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class R2LTransformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        window_size,
        depth = 4,
        heads = 4,
        dim_head = 32,
        attn_dropout = 0.,
        ff_dropout = 0.,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        self.window_size = window_size
        rel_positions = 2 * window_size - 1
        self.local_rel_pos_bias = nn.Embedding(rel_positions ** 2, heads)

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout)
            ]))

    def forward(self, local_tokens, region_tokens):
        device = local_tokens.device
        lh, lw = local_tokens.shape[-2:]
        rh, rw = region_tokens.shape[-2:]
        window_size_h, window_size_w = lh // rh, lw // rw

        local_tokens = rearrange(local_tokens, 'b c h w -> b (h w) c')
        region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')

        # calculate local relative positional bias

        h_range = torch.arange(window_size_h, device = device)
        w_range = torch.arange(window_size_w, device = device)

        grid_x, grid_y = torch.meshgrid(h_range, w_range)
        grid = torch.stack((grid_x, grid_y))
        grid = rearrange(grid, 'c h w -> c (h w)')
        grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)
        bias_indices = (grid * torch.tensor([1, self.window_size * 2 - 1], device = device)[:, None, None]).sum(dim = 0)
        rel_pos_bias = self.local_rel_pos_bias(bias_indices)
        rel_pos_bias = rearrange(rel_pos_bias, 'i j h -> () h i j')
        rel_pos_bias = F.pad(rel_pos_bias, (1, 0, 1, 0), value = 0)

        # go through r2l transformer layers

        for attn, ff in self.layers:
            region_tokens = attn(region_tokens) + region_tokens

            # concat region tokens to local tokens

            local_tokens = rearrange(local_tokens, 'b (h w) d -> b h w d', h = lh)
            local_tokens = rearrange(local_tokens, 'b (h p1) (w p2) d -> (b h w) (p1 p2) d', p1 = window_size_h, p2 = window_size_w)
            region_tokens = rearrange(region_tokens, 'b n d -> (b n) () d')

            # do self attention on local tokens, along with its regional token

            region_and_local_tokens = torch.cat((region_tokens, local_tokens), dim = 1)
            region_and_local_tokens = attn(region_and_local_tokens, rel_pos_bias = rel_pos_bias) + region_and_local_tokens

            # feedforward

            region_and_local_tokens = ff(region_and_local_tokens) + region_and_local_tokens

            # split back local and regional tokens

            region_tokens, local_tokens = region_and_local_tokens[:, :1], region_and_local_tokens[:, 1:]
            local_tokens = rearrange(local_tokens, '(b h w) (p1 p2) d -> b (h p1 w p2) d', h = lh // window_size_h, w = lw // window_size_w, p1 = window_size_h)
            region_tokens = rearrange(region_tokens, '(b n) () d -> b n d', n = rh * rw)

        local_tokens = rearrange(local_tokens, 'b (h w) c -> b c h w', h = lh, w = lw)
        region_tokens = rearrange(region_tokens, 'b (h w) c -> b c h w', h = rh, w = rw)
        return local_tokens, region_tokens

# classes

class RegionViT(nn.Module):
    def __init__(
        self,
        *,
        dim = (64, 128, 256, 512),
        depth = (2, 2, 8, 2),
        window_size = 7,
        num_classes = 1000,
        tokenize_local_3_conv = False,
        local_patch_size = 4,
        use_peg = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        channels = 3,
    ):
        super().__init__()
        dim = cast_tuple(dim, 4)
        depth = cast_tuple(depth, 4)
        assert len(dim) == 4, 'dim needs to be a single value or a tuple of length 4'
        assert len(depth) == 4, 'depth needs to be a single value or a tuple of length 4'

        self.local_patch_size = local_patch_size

        region_patch_size = local_patch_size * window_size
        self.region_patch_size = local_patch_size * window_size

        init_dim, *_, last_dim = dim

        # local and region encoders

        if tokenize_local_3_conv:
            self.local_encoder = nn.Sequential(
                nn.Conv2d(3, init_dim, 3, 2, 1),
                nn.LayerNorm(init_dim),
                nn.GELU(),
                nn.Conv2d(init_dim, init_dim, 3, 2, 1),
                nn.LayerNorm(init_dim),
                nn.GELU(),
                nn.Conv2d(init_dim, init_dim, 3, 1, 1)
            )
        else:
            self.local_encoder = nn.Conv2d(3, init_dim, 8, 4, 3)

        self.region_encoder = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = region_patch_size, p2 = region_patch_size),
            nn.Conv2d((region_patch_size ** 2) * channels, init_dim, 1)
        )

        # layers

        current_dim = init_dim
        self.layers = nn.ModuleList([])

        for ind, dim, num_layers in zip(range(4), dim, depth):
            not_first = ind != 0
            need_downsample = not_first
            need_peg = not_first and use_peg

            self.layers.append(nn.ModuleList([
                Downsample(current_dim, dim) if need_downsample else nn.Identity(),
                PEG(dim) if need_peg else nn.Identity(),
                R2LTransformer(dim, depth = num_layers, window_size = window_size, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
            ]))

            current_dim = dim

        # final logits

        self.to_logits = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.LayerNorm(last_dim),
            nn.Linear(last_dim, num_classes)
        )

    def forward(
        self,
        x,
        return_local_tokens = False
    ):
        *_, h, w = x.shape
        assert divisible_by(h, self.region_patch_size) and divisible_by(w, self.region_patch_size), 'height and width must be divisible by region patch size'
        assert divisible_by(h, self.local_patch_size) and divisible_by(w, self.local_patch_size), 'height and width must be divisible by local patch size'

        local_tokens = self.local_encoder(x)
        region_tokens = self.region_encoder(x)

        for down, peg, transformer in self.layers:
            local_tokens, region_tokens = down(local_tokens), down(region_tokens)
            local_tokens = peg(local_tokens)
            local_tokens, region_tokens = transformer(local_tokens, region_tokens)

        return self.to_logits(region_tokens)
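To see how `local_rel_pos_bias` in `R2LTransformer` is indexed, here is a small standalone sketch for a 2 x 2 window. It mirrors the computation above, with `indexing = 'ij'` spelled out for newer PyTorch versions; the shapes and values are illustrative only.

```python
import torch
from einops import rearrange

window_size = 2                        # a 2 x 2 local window
rel_positions = 2 * window_size - 1    # 3 possible offsets per axis: -1, 0, +1

coords = torch.arange(window_size)
grid_x, grid_y = torch.meshgrid(coords, coords, indexing = 'ij')
grid = torch.stack((grid_x, grid_y))          # (2, window_size, window_size)
grid = rearrange(grid, 'c h w -> c (h w)')    # coordinates of the 4 window positions

# pairwise offsets between positions, shifted into the range [0, 2 * window_size - 2]
rel = (grid[:, :, None] - grid[:, None, :]) + (window_size - 1)

# flatten the 2d offset into a single index into the (rel_positions ** 2, heads) embedding table
bias_indices = (rel * torch.tensor([1, rel_positions])[:, None, None]).sum(dim = 0)
print(bias_indices)    # (4, 4) matrix of indices in [0, rel_positions ** 2 - 1]
```

In the module itself, the resulting (i, j) bias is then padded with a zero row and column on the left and top, so the prepended regional token receives no positional bias.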
@@ -19,7 +19,7 @@ class AxialRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_freq = 10):
        super().__init__()
        self.dim = dim
        scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
        scales = torch.linspace(1., max_freq / 2, self.dim // 4)
        self.register_buffer('scales', scales)

    def forward(self, x):
@@ -154,10 +154,10 @@ class Attention(nn.Module):
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, image_size, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.pos_emb = AxialRotaryEmbedding(dim_head)
        self.pos_emb = AxialRotaryEmbedding(dim_head, max_freq = image_size)
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
@@ -187,7 +187,7 @@ class RvT(nn.Module):
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary, use_ds_conv, use_glu)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, image_size, dropout, use_rotary, use_ds_conv, use_glu)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
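The first hunk above swaps the frequency schedule of the axial rotary embedding from log-spaced to linearly spaced scales, and the later hunks plumb the image size through as `max_freq`. A quick standalone comparison of the two schedules, detached from the module and with arbitrary sizes:

```python
import torch
from math import log

max_freq, n = 10, 8   # n corresponds to dim_head // 4 in the module above

# old schedule: 1 ... max_freq / 2, log-spaced
log_spaced = torch.logspace(0., log(max_freq / 2) / log(2), n, base = 2)

# new schedule: 1 ... max_freq / 2, evenly spaced
lin_spaced = torch.linspace(1., max_freq / 2, n)

print(log_spaced)
print(lin_spaced)
```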
vit_pytorch/simmim.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

class SimMIM(nn.Module):
    def __init__(
        self,
        *,
        encoder,
        masking_ratio = 0.5
    ):
        super().__init__()
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # extract some hyperparameters and functions from encoder (vision transformer to be trained)

        self.encoder = encoder
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
        self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
        pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]

        # simple linear head

        self.mask_token = nn.Parameter(torch.randn(encoder_dim))
        self.to_pixels = nn.Linear(encoder_dim, pixel_values_per_patch)

    def forward(self, img):
        device = img.device

        # get patches

        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        # for indexing purposes

        batch_range = torch.arange(batch, device = device)[:, None]

        # get positions

        pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)]

        # patch to encoder tokens and add positions

        tokens = self.patch_to_emb(patches)
        tokens = tokens + pos_emb

        # prepare mask tokens

        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_patches)
        mask_tokens = mask_tokens + pos_emb

        # calculate the number of patches needed to be masked, and get positions (indices) to be masked

        num_masked = int(self.masking_ratio * num_patches)
        masked_indices = torch.rand(batch, num_patches, device = device).topk(k = num_masked, dim = -1).indices
        masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool()

        # mask tokens

        tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)

        # attend with vision transformer

        encoded = self.encoder.transformer(tokens)

        # get the masked tokens

        encoded_mask_tokens = encoded[batch_range, masked_indices]

        # small linear projection for predicted pixel values

        pred_pixel_values = self.to_pixels(encoded_mask_tokens)

        # get the masked patches for the final reconstruction loss

        masked_patches = patches[batch_range, masked_indices]

        # calculate reconstruction loss

        recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked
        return recon_loss
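The masking here differs from the MAE module earlier in this diff: instead of splitting indices with `argsort` and encoding only the visible tokens, SimMIM keeps the full sequence and swaps masked positions for a learned mask token via a boolean mask. A standalone sketch of just that selection step, with arbitrary small shapes and zeros standing in for the learned mask token:

```python
import torch

batch, num_patches, dim = 2, 6, 4
masking_ratio = 0.5
num_masked = int(masking_ratio * num_patches)    # 3

tokens = torch.randn(batch, num_patches, dim)
mask_tokens = torch.zeros(batch, num_patches, dim)   # stand-in for the learned mask token plus positions

# pick num_masked random positions per sample
masked_indices = torch.rand(batch, num_patches).topk(k = num_masked, dim = -1).indices

# turn the index list into a boolean mask over the sequence
masked_bool_mask = torch.zeros((batch, num_patches)).scatter_(-1, masked_indices, 1).bool()

# replace masked positions with the mask token, keep the rest untouched
mixed = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)
print(mixed.shape)    # torch.Size([2, 6, 4])
```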
@@ -72,7 +72,7 @@ class T2TViT(nn.Module):

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x += self.pos_embedding[:, :n+1]
        x = self.dropout(x)

        x = self.transformer(x)
@@ -1,6 +1,5 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
@@ -51,15 +50,14 @@ class Attention(nn.Module):
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
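The second hunk above replaces the `einsum` contractions in `Attention` with plain `matmul`, which is numerically equivalent. A quick standalone check of that equivalence, with arbitrary shapes:

```python
import torch

q = torch.randn(2, 8, 16, 64)   # (batch, heads, seq, dim_head)
k = torch.randn(2, 8, 16, 64)

dots_einsum = torch.einsum('b h i d, b h j d -> b h i j', q, k)
dots_matmul = torch.matmul(q, k.transpose(-1, -2))

print(torch.allclose(dots_einsum, dots_matmul, atol = 1e-6))   # True
```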