Compare commits

...

5 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Phil Wang | 0b347973fe | add EsViT, by popular request, an alternative to Dino that is compatible with efficient ViTs with accounting for regional self-supervised loss | 2022-05-03 10:18:36 -07:00 |
| Zhengzhong Tu | c2aab05ebf | fix bibtex typo (#212) | 2022-04-06 22:15:05 -07:00 |
| Phil Wang | 81661e3966 | fix mbconv residual block | 2022-04-06 16:43:06 -07:00 |
| Phil Wang | 13f8e123bb | fix maxvit - need feedforwards after attention | 2022-04-06 16:34:40 -07:00 |
| Phil Wang | 2d4089c88e | link to maxvit in readme | 2022-04-06 16:24:12 -07:00 |
4 changed files with 100 additions and 7 deletions

README.md (View File)

@@ -32,6 +32,7 @@
- [Parallel ViT](#parallel-vit)
- [Learnable Memory ViT](#learnable-memory-vit)
- [Dino](#dino)
- [EsViT](#esvit)
- [Accessing Attention](#accessing-attention)
- [Research Ideas](#research-ideas)
* [Efficient Attention](#efficient-attention)
@@ -601,7 +602,7 @@ preds = v(img) # (1, 1000)
<img src="./images/max-vit.png" width="400px"></img>
This paper proposes a hybrid convolutional / attention network, using MBConv from the convolution side, and then block / grid axial sparse attention.
<a href="https://arxiv.org/abs/2204.01697">This paper</a> proposes a hybrid convolutional / attention network, using MBConv from the convolution side, and then block / grid axial sparse attention.
They also claim this specific vision transformer is good for generative models (GANs).
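For reference, a minimal usage sketch in the style of the README's other examples; the constructor arguments shown here (stem dim, per-stage depths, window size, MBConv expansion / shrinkage rates) are illustrative and may not match this exact version of `max_vit.py`.

```python
import torch
from vit_pytorch.max_vit import MaxViT

v = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,           # dim of the initial conv stem
    dim = 96,                     # dim of the first stage, doubled at every following stage
    dim_head = 32,                # dim per attention head
    depth = (2, 2, 5, 2),         # number of MaxViT blocks per stage
    window_size = 7,              # window size used for both block and grid attention
    mbconv_expansion_rate = 4,    # expansion rate of the MBConv inverted bottleneck
    mbconv_shrinkage_rate = 0.25, # shrinkage rate of the squeeze-excitation
    dropout = 0.1
)

img = torch.randn(2, 3, 224, 224)
preds = v(img) # (2, 1000)
```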
@@ -1076,6 +1077,80 @@ for _ in range(100):
torch.save(model.state_dict(), './pretrained-net.pt')
```
## EsViT
<img src="./images/esvit.png" width="350px"></img>
`EsViT` is a variant of Dino (from above) re-engineered to support efficient `ViT`s with patch merging / downsampling, by taking into account an extra regional loss between the augmented views. To quote the abstract, it `outperforms its supervised counterpart on 17 out of 18 datasets` at 3 times higher throughput.

Even though it is named as though it were a new `ViT` variant, it is really just a strategy for training any multistage `ViT` (in the paper, they focused on Swin). The example below will show how to use it with `CvT`. You'll need to set `hidden_layer` to the name of the layer within your efficient ViT that outputs the non-average-pooled visual representations, just before the global pooling and projection to logits.
```python
import torch
from vit_pytorch.cvt import CvT
from vit_pytorch.es_vit import EsViTTrainer

cvt = CvT(
    num_classes = 1000,
    s1_emb_dim = 64,
    s1_emb_kernel = 7,
    s1_emb_stride = 4,
    s1_proj_kernel = 3,
    s1_kv_proj_stride = 2,
    s1_heads = 1,
    s1_depth = 1,
    s1_mlp_mult = 4,
    s2_emb_dim = 192,
    s2_emb_kernel = 3,
    s2_emb_stride = 2,
    s2_proj_kernel = 3,
    s2_kv_proj_stride = 2,
    s2_heads = 3,
    s2_depth = 2,
    s2_mlp_mult = 4,
    s3_emb_dim = 384,
    s3_emb_kernel = 3,
    s3_emb_stride = 2,
    s3_proj_kernel = 3,
    s3_kv_proj_stride = 2,
    s3_heads = 4,
    s3_depth = 10,
    s3_mlp_mult = 4,
    dropout = 0.
)

learner = EsViTTrainer(
    cvt,
    image_size = 256,
    hidden_layer = 'layers',           # hidden layer name or index, from which to extract the embedding
    projection_hidden_size = 256,      # projector network hidden dimension
    projection_layers = 4,             # number of layers in projection network
    num_classes_K = 65336,             # output logits dimensions (referenced as K in paper)
    student_temp = 0.9,                # student temperature
    teacher_temp = 0.04,               # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs (see the schedule sketch after this example)
    local_upper_crop_scale = 0.4,      # upper bound for local crop - 0.4 was recommended in the paper
    global_lower_crop_scale = 0.5,     # lower bound for global crop - 0.5 was recommended in the paper
    moving_average_decay = 0.9,        # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
    center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)

opt = torch.optim.AdamW(learner.parameters(), lr = 3e-4)

def sample_unlabelled_images():
    return torch.randn(8, 3, 256, 256)

for _ in range(1000):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of teacher encoder and teacher centers

# save your improved network
torch.save(cvt.state_dict(), './pretrained-net.pt')
```
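The `teacher_temp` comment above asks for annealing from 0.04 to 0.07 over the first 30 epochs. Below is a minimal sketch of that linear warmup; how the value is fed back to `EsViTTrainer` each epoch is an assumption here (shown as setting an attribute), so check `es_vit.py` for the mechanism your version actually exposes.

```python
def teacher_temp_schedule(epoch, warmup_epochs = 30, start_temp = 0.04, final_temp = 0.07):
    # linearly warm the teacher temperature to its final value, then hold it constant
    if epoch >= warmup_epochs:
        return final_temp
    return start_temp + (final_temp - start_temp) * (epoch / warmup_epochs)

# hypothetical usage per epoch - assumes the trainer reads `teacher_temp` as a plain attribute
# learner.teacher_temp = teacher_temp_schedule(epoch)
```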
## Accessing Attention
If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below
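That procedure sits outside this hunk, so as a reference, here is a minimal sketch using the `Recorder` wrapper that ships with `vit_pytorch`; the attention tensor shape shown follows from this particular configuration (depth 6, 16 heads, 64 patches plus the CLS token).

```python
import torch
from vit_pytorch import ViT
from vit_pytorch.recorder import Recorder

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# wrap the model so each forward pass also returns the post-softmax attention maps
v = Recorder(v)

img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)
# attns: (1, 6, 16, 65, 65) - (batch, depth, heads, cls + patch tokens, cls + patch tokens)
```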
@@ -1579,7 +1654,7 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
year = {2022}
}
```

setup.py (View File)

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'vit-pytorch',
  packages = find_packages(exclude=['examples']),
  version = '0.33.0',
  version = '0.34.0',
  license='MIT',
  description = 'Vision Transformer (ViT) - Pytorch',
  author = 'Phil Wang',

vit_pytorch/cvt.py (View File)

@@ -164,12 +164,14 @@ class CvT(nn.Module):
            dim = config['emb_dim']

        self.layers = nn.Sequential(
            *layers,
        self.layers = nn.Sequential(*layers)

        self.to_logits = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...'),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        return self.layers(x)
        latents = self.layers(x)
        return self.to_logits(latents)
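This refactor is what makes `hidden_layer = 'layers'` in the README example meaningful: `self.layers` now stops before global pooling, and `self.to_logits` handles pooling and classification. A minimal sketch of pulling that intermediate output with a plain forward hook follows; the default CvT hyperparameters and the exact feature-map shape are assumptions, and `EsViTTrainer` presumably does something equivalent internally.

```python
import torch
from vit_pytorch.cvt import CvT

cvt = CvT(num_classes = 1000)  # assumes the stage hyperparameters all carry defaults

features = {}

def capture(module, inputs, output):
    # stash the non-pooled feature map produced by the 'layers' module
    features['layers'] = output

cvt.layers.register_forward_hook(capture)

img = torch.randn(1, 3, 256, 256)
preds = cvt(img)           # (1, 1000) logits from to_logits
fmap = features['layers']  # regional feature map, roughly (1, dim, h, w), before global pooling
```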

vit_pytorch/max_vit.py (View File)

@@ -28,6 +28,20 @@ class PreNormResidual(nn.Module):
    def forward(self, x):
        return self.fn(self.norm(x)) + x

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = int(dim * mult)

        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# MBConv

class SqueezeExcitation(nn.Module):
@@ -57,7 +71,7 @@ class MBConvResidual(nn.Module):
    def forward(self, x):
        out = self.fn(x)
        out = self.dropsample(out)
        return out
        return out + x

class Dropsample(nn.Module):
    def __init__(self, prob = 0):
@@ -244,10 +258,12 @@ class MaxViT(nn.Module):
                    ),
                    Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
                    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),

                    Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
                    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                )