mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-05-18 06:57:47 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5ae555750f | ||
|
|
c5a461661c | ||
|
|
e212918e2d | ||
|
|
dc57c75478 | ||
|
|
99c44cf5f6 | ||
|
|
5b16e8f809 |
117
README.md
117
README.md
@@ -1,5 +1,35 @@
|
||||
<img src="./images/vit.gif" width="500px"></img>
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Vision Transformer - Pytorch](#vision-transformer---pytorch)
|
||||
- [Install](#install)
|
||||
- [Usage](#usage)
|
||||
- [Parameters](#parameters)
|
||||
- [Distillation](#distillation)
|
||||
- [Deep ViT](#deep-vit)
|
||||
- [CaiT](#cait)
|
||||
- [Token-to-Token ViT](#token-to-token-vit)
|
||||
- [CCT](#cct)
|
||||
- [Cross ViT](#cross-vit)
|
||||
- [PiT](#pit)
|
||||
- [LeViT](#levit)
|
||||
- [CvT](#cvt)
|
||||
- [Twins SVT](#twins-svt)
|
||||
- [RegionViT](#regionvit)
|
||||
- [NesT](#nest)
|
||||
- [Masked Autoencoder](#masked-autoencoder)
|
||||
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
|
||||
- [Masked Patch Prediction](#masked-patch-prediction)
|
||||
- [Dino](#dino)
|
||||
- [Accessing Attention](#accessing-attention)
|
||||
- [Research Ideas](#research-ideas)
|
||||
* [Efficient Attention](#efficient-attention)
|
||||
* [Combining with other Transformer improvements](#combining-with-other-transformer-improvements)
|
||||
- [FAQ](#faq)
|
||||
- [Resources](#resources)
|
||||
- [Citations](#citations)
|
||||
|
||||
## Vision Transformer - Pytorch
|
||||
|
||||
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
|
||||
@@ -453,7 +483,7 @@ model = RegionViT(
|
||||
dim = (64, 128, 256, 512), # tuple of size 4, indicating dimension at each stage
|
||||
depth = (2, 2, 8, 2), # depth of the region to local transformer at each stage
|
||||
window_size = 7, # window size, which should be either 7 or 14
|
||||
num_classes = 1000, # number of output lcasses
|
||||
num_classes = 1000, # number of output classes
|
||||
tokenize_local_3_conv = False, # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
|
||||
use_peg = False, # whether to use positional generating module. they used this for object detection for a boost in performance
|
||||
)
|
||||
@@ -490,12 +520,54 @@ img = torch.randn(1, 3, 224, 224)
|
||||
pred = nest(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Simple Masked Image Modeling
|
||||
|
||||
<img src="./images/simmim.png" width="400px"/>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2111.09886">paper</a> proposes a simple masked image modeling (SimMIM) scheme, using only a linear projection off the masked tokens into pixel space followed by an L1 loss with the pixel values of the masked patches. Results are competitive with other more complicated approaches.
|
||||
|
||||
You can use this as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import ViT
|
||||
from vit_pytorch.simmim import SimMIM
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
mim = SimMIM(
|
||||
encoder = v,
|
||||
masking_ratio = 0.5 # they found 50% to yield the best results
|
||||
)
|
||||
|
||||
images = torch.randn(8, 3, 256, 256)
|
||||
|
||||
loss = mim(images)
|
||||
loss.backward()
|
||||
|
||||
# that's all!
|
||||
# do the above in a for loop many times with a lot of images and your vision transformer will learn
|
||||
|
||||
torch.save(v.state_dict(), './trained-vit.pt')
|
||||
```
|
||||
|
||||
|
||||
## Masked Autoencoder
|
||||
|
||||
<img src="./images/mae.png" width="400px"/>
|
||||
|
||||
A new <a href="https://arxiv.org/abs/2111.06377">Kaiming He paper</a> proposes a simple autoencoder scheme where the vision transformer attends to a set of unmasked patches, and a smaller decoder tries to reconstruct the masked pixel values.
|
||||
|
||||
<a href="https://www.youtube.com/watch?v=LKixq2S2Pz8">DeepReader quick paper review</a>
|
||||
|
||||
You can use it with the following code
|
||||
|
||||
```python
|
||||
@@ -514,10 +586,9 @@ v = ViT(
|
||||
|
||||
mae = MAE(
|
||||
encoder = v,
|
||||
masking_ratio = 0.75,
|
||||
decoder_dim = 1024,
|
||||
decoder_depth = 6,
|
||||
decoder_heads = 8
|
||||
masking_ratio = 0.75, # the paper recommended 75% masked patches
|
||||
decoder_dim = 512, # paper showed good results with just 512
|
||||
decoder_depth = 6 # anywhere from 1 to 8
|
||||
)
|
||||
|
||||
images = torch.randn(8, 3, 256, 256)
|
||||
@@ -527,6 +598,9 @@ loss.backward()
|
||||
|
||||
# that's all!
|
||||
# do the above in a for loop many times with a lot of images and your vision transformer will learn
|
||||
|
||||
# save your improved vision transformer
|
||||
torch.save(v.state_dict(), './trained-vit.pt')
|
||||
```
|
||||
|
||||
## Masked Patch Prediction
|
||||
@@ -807,13 +881,13 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
## Citations
|
||||
```bibtex
|
||||
@article{hassani2021escaping,
|
||||
title = {Escaping the Big Data Paradigm with Compact Transformers},
|
||||
author = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
|
||||
year = 2021,
|
||||
url = {https://arxiv.org/abs/2104.05704},
|
||||
eprint = {2104.05704},
|
||||
archiveprefix = {arXiv},
|
||||
primaryclass = {cs.CV}
|
||||
title = {Escaping the Big Data Paradigm with Compact Transformers},
|
||||
author = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
|
||||
year = 2021,
|
||||
url = {https://arxiv.org/abs/2104.05704},
|
||||
eprint = {2104.05704},
|
||||
archiveprefix = {arXiv},
|
||||
primaryclass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -841,10 +915,10 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```bibtex
|
||||
@misc{yuan2021tokenstotoken,
|
||||
title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
|
||||
author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
|
||||
year = {2021},
|
||||
eprint = {2101.11986},
|
||||
title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
|
||||
author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
|
||||
year = {2021},
|
||||
eprint = {2101.11986},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
@@ -993,6 +1067,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{xie2021simmim,
|
||||
title = {SimMIM: A Simple Framework for Masked Image Modeling},
|
||||
author = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
|
||||
year = {2021},
|
||||
eprint = {2111.09886},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
|
||||
BIN
images/simmim.png
Normal file
BIN
images/simmim.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 365 KiB |
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.22.0',
|
||||
version = '0.23.2',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import torch
|
||||
from math import ceil
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from einops import repeat
|
||||
|
||||
from vit_pytorch.vit import Transformer
|
||||
|
||||
|
||||
84
vit_pytorch/simmim.py
Normal file
84
vit_pytorch/simmim.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
|
||||
class SimMIM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
encoder,
|
||||
masking_ratio = 0.5
|
||||
):
|
||||
super().__init__()
|
||||
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
|
||||
self.masking_ratio = masking_ratio
|
||||
|
||||
# extract some hyperparameters and functions from encoder (vision transformer to be trained)
|
||||
|
||||
self.encoder = encoder
|
||||
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
|
||||
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
|
||||
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
|
||||
|
||||
# simple linear head
|
||||
|
||||
self.mask_token = nn.Parameter(torch.randn(encoder_dim))
|
||||
self.to_pixels = nn.Linear(encoder_dim, pixel_values_per_patch)
|
||||
|
||||
def forward(self, img):
|
||||
device = img.device
|
||||
|
||||
# get patches
|
||||
|
||||
patches = self.to_patch(img)
|
||||
batch, num_patches, *_ = patches.shape
|
||||
|
||||
# for indexing purposes
|
||||
|
||||
batch_range = torch.arange(batch, device = device)[:, None]
|
||||
|
||||
# get positions
|
||||
|
||||
pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)]
|
||||
|
||||
# patch to encoder tokens and add positions
|
||||
|
||||
tokens = self.patch_to_emb(patches)
|
||||
tokens = tokens + pos_emb
|
||||
|
||||
# prepare mask tokens
|
||||
|
||||
mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_patches)
|
||||
mask_tokens = mask_tokens + pos_emb
|
||||
|
||||
# calculate of patches needed to be masked, and get positions (indices) to be masked
|
||||
|
||||
num_masked = int(self.masking_ratio * num_patches)
|
||||
masked_indices = torch.rand(batch, num_patches, device = device).topk(k = num_masked, dim = -1).indices
|
||||
masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool()
|
||||
|
||||
# mask tokens
|
||||
|
||||
tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)
|
||||
|
||||
# attend with vision transformer
|
||||
|
||||
encoded = self.encoder.transformer(tokens)
|
||||
|
||||
# get the masked tokens
|
||||
|
||||
encoded_mask_tokens = encoded[batch_range, masked_indices]
|
||||
|
||||
# small linear projection for predicted pixel values
|
||||
|
||||
pred_pixel_values = self.to_pixels(encoded_mask_tokens)
|
||||
|
||||
# get the masked patches for the final reconstruction loss
|
||||
|
||||
masked_patches = patches[batch_range, masked_indices]
|
||||
|
||||
# calculate reconstruction loss
|
||||
|
||||
recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked
|
||||
return recon_loss
|
||||
Reference in New Issue
Block a user