mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 16:12:29 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
681a7407e9 | ||
|
|
86a7302ba6 | ||
|
|
89d3a04b3f | ||
|
|
e7075c64aa | ||
|
|
5ea1559e4c | ||
|
|
f4b0b14094 | ||
|
|
365b4d931e | ||
|
|
79c864d796 | ||
|
|
b45c1356a1 |
73
README.md
73
README.md
@@ -19,9 +19,11 @@
|
||||
- [CrossFormer](#crossformer)
|
||||
- [RegionViT](#regionvit)
|
||||
- [NesT](#nest)
|
||||
- [MobileViT](#mobilevit)
|
||||
- [Masked Autoencoder](#masked-autoencoder)
|
||||
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
|
||||
- [Masked Patch Prediction](#masked-patch-prediction)
|
||||
- [Adaptive Token Sampling](#adaptive-token-sampling)
|
||||
- [Dino](#dino)
|
||||
- [Accessing Attention](#accessing-attention)
|
||||
- [Research Ideas](#research-ideas)
|
||||
@@ -548,6 +550,31 @@ img = torch.randn(1, 3, 224, 224)
|
||||
pred = nest(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## MobileViT
|
||||
|
||||
<img src="./images/mbvit.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2110.02178">paper</a> introduce MobileViT, a light-weight and generalpurpose vision transformer for mobile devices. MobileViT presents a different
|
||||
perspective for the global processing of information with transformers.
|
||||
|
||||
You can use it with the following code (ex. mobilevit_xs)
|
||||
|
||||
```
|
||||
import torch
|
||||
from vit_pytorch.mobile_vit import MobileViT
|
||||
|
||||
mbvit_xs = MobileViT(
|
||||
image_size=(256, 256),
|
||||
dims = [96, 120, 144],
|
||||
channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
|
||||
num_classes = 1000
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
pred = mbvit_xs(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Simple Masked Image Modeling
|
||||
|
||||
<img src="./images/simmim.png" width="400px"/>
|
||||
@@ -596,6 +623,8 @@ A new <a href="https://arxiv.org/abs/2111.06377">Kaiming He paper</a> proposes a
|
||||
|
||||
<a href="https://www.youtube.com/watch?v=LKixq2S2Pz8">DeepReader quick paper review</a>
|
||||
|
||||
<a href="https://www.youtube.com/watch?v=Dp6iICL2dVI">AI Coffeebreak with Letitia</a>
|
||||
|
||||
You can use it with the following code
|
||||
|
||||
```python
|
||||
@@ -677,6 +706,39 @@ for _ in range(100):
|
||||
torch.save(model.state_dict(), './pretrained-net.pt')
|
||||
```
|
||||
|
||||
## Adaptive Token Sampling
|
||||
|
||||
<img src="./images/ats.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2111.15667">paper</a> proposes to use the CLS attention scores, re-weighed by the norms of the value heads, as means to discard unimportant tokens at different layers.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.ats_vit import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
max_tokens_per_depth = (256, 128, 64, 32, 16, 8), # a tuple that denotes the maximum number of tokens that any given layer should have. if the layer has greater than this amount, it will undergo adaptive token sampling
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
|
||||
# you can also get a list of the final sampled patch ids
|
||||
# a value of -1 denotes padding
|
||||
|
||||
preds, token_ids = v(img, return_sampled_token_ids = True) # (1, 1000), (1, <=8)
|
||||
```
|
||||
|
||||
## Dino
|
||||
|
||||
<img src="./images/dino.png" width="350px"></img>
|
||||
@@ -1117,6 +1179,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{fayyaz2021ats,
|
||||
title = {ATS: Adaptive Token Sampling For Efficient Vision Transformers},
|
||||
author = {Mohsen Fayyaz and Soroush Abbasi Kouhpayegani and Farnoush Rezaei Jafari and Eric Sommerlade and Hamid Reza Vaezi Joze and Hamed Pirsiavash and Juergen Gall},
|
||||
year = {2021},
|
||||
eprint = {2111.15667},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
|
||||
BIN
images/ats.png
Normal file
BIN
images/ats.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 198 KiB |
BIN
images/mbvit.png
Normal file
BIN
images/mbvit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 206 KiB |
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.24.2',
|
||||
version = '0.25.0',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
262
vit_pytorch/ats_vit.py
Normal file
262
vit_pytorch/ats_vit.py
Normal file
@@ -0,0 +1,262 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch import nn, einsum
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# adaptive token sampling functions and classes
|
||||
|
||||
def log(t, eps = 1e-6):
|
||||
return torch.log(t + eps)
|
||||
|
||||
def sample_gumbel(shape, device, dtype, eps = 1e-6):
|
||||
u = torch.empty(shape, device = device, dtype = dtype).uniform_(0, 1)
|
||||
return -log(-log(u, eps), eps)
|
||||
|
||||
def batched_index_select(values, indices, dim = 1):
|
||||
value_dims = values.shape[(dim + 1):]
|
||||
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
|
||||
indices = indices[(..., *((None,) * len(value_dims)))]
|
||||
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
|
||||
value_expand_len = len(indices_shape) - (dim + 1)
|
||||
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
|
||||
|
||||
value_expand_shape = [-1] * len(values.shape)
|
||||
expand_slice = slice(dim, (dim + value_expand_len))
|
||||
value_expand_shape[expand_slice] = indices.shape[expand_slice]
|
||||
values = values.expand(*value_expand_shape)
|
||||
|
||||
dim += value_expand_len
|
||||
return values.gather(dim, indices)
|
||||
|
||||
class AdaptiveTokenSampling(nn.Module):
|
||||
def __init__(self, output_num_tokens, eps = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.output_num_tokens = output_num_tokens
|
||||
|
||||
def forward(self, attn, value, mask):
|
||||
heads, output_num_tokens, eps, device, dtype = attn.shape[1], self.output_num_tokens, self.eps, attn.device, attn.dtype
|
||||
|
||||
# first get the attention values for CLS token to all other tokens
|
||||
|
||||
cls_attn = attn[..., 0, 1:]
|
||||
|
||||
# calculate the norms of the values, for weighting the scores, as described in the paper
|
||||
|
||||
value_norms = value[..., 1:, :].norm(dim = -1)
|
||||
|
||||
# weigh the attention scores by the norm of the values, sum across all heads
|
||||
|
||||
cls_attn = einsum('b h n, b h n -> b n', cls_attn, value_norms)
|
||||
|
||||
# normalize to 1
|
||||
|
||||
normed_cls_attn = cls_attn / (cls_attn.sum(dim = -1, keepdim = True) + eps)
|
||||
|
||||
# instead of using inverse transform sampling, going to invert the softmax and use gumbel-max sampling instead
|
||||
|
||||
pseudo_logits = log(normed_cls_attn)
|
||||
|
||||
# mask out pseudo logits for gumbel-max sampling
|
||||
|
||||
mask_without_cls = mask[:, 1:]
|
||||
mask_value = -torch.finfo(attn.dtype).max / 2
|
||||
pseudo_logits = pseudo_logits.masked_fill(~mask_without_cls, mask_value)
|
||||
|
||||
# expand k times, k being the adaptive sampling number
|
||||
|
||||
pseudo_logits = repeat(pseudo_logits, 'b n -> b k n', k = output_num_tokens)
|
||||
pseudo_logits = pseudo_logits + sample_gumbel(pseudo_logits.shape, device = device, dtype = dtype)
|
||||
|
||||
# gumble-max and add one to reserve 0 for padding / mask
|
||||
|
||||
sampled_token_ids = pseudo_logits.argmax(dim = -1) + 1
|
||||
|
||||
# calculate unique using torch.unique and then pad the sequence from the right
|
||||
|
||||
unique_sampled_token_ids_list = [torch.unique(t, sorted = True) for t in torch.unbind(sampled_token_ids)]
|
||||
unique_sampled_token_ids = pad_sequence(unique_sampled_token_ids_list, batch_first = True)
|
||||
|
||||
# calculate the new mask, based on the padding
|
||||
|
||||
new_mask = unique_sampled_token_ids != 0
|
||||
|
||||
# CLS token never gets masked out (gets a value of True)
|
||||
|
||||
new_mask = F.pad(new_mask, (1, 0), value = True)
|
||||
|
||||
# prepend a 0 token id to keep the CLS attention scores
|
||||
|
||||
unique_sampled_token_ids = F.pad(unique_sampled_token_ids, (1, 0), value = 0)
|
||||
expanded_unique_sampled_token_ids = repeat(unique_sampled_token_ids, 'b n -> b h n', h = heads)
|
||||
|
||||
# gather the new attention scores
|
||||
|
||||
new_attn = batched_index_select(attn, expanded_unique_sampled_token_ids, dim = 2)
|
||||
|
||||
# return the sampled attention scores, new mask (denoting padding), as well as the sampled token indices (for the residual)
|
||||
return new_attn, new_mask, unique_sampled_token_ids
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_tokens = None):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.output_num_tokens = output_num_tokens
|
||||
self.ats = AdaptiveTokenSampling(output_num_tokens) if exists(output_num_tokens) else None
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, *, mask):
|
||||
num_tokens = x.shape[1]
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
dots_mask = rearrange(mask, 'b i -> b 1 i 1') * rearrange(mask, 'b j -> b 1 1 j')
|
||||
mask_value = -torch.finfo(dots.dtype).max
|
||||
dots = dots.masked_fill(~dots_mask, mask_value)
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
sampled_token_ids = None
|
||||
|
||||
# if adaptive token sampling is enabled
|
||||
# and number of tokens is greater than the number of output tokens
|
||||
if exists(self.output_num_tokens) and (num_tokens - 1) > self.output_num_tokens:
|
||||
attn, mask, sampled_token_ids = self.ats(attn, v, mask = mask)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
return self.to_out(out), mask, sampled_token_ids
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
assert len(max_tokens_per_depth) == depth, 'max_tokens_per_depth must be a tuple of length that is equal to the depth of the transformer'
|
||||
assert sorted(max_tokens_per_depth, reverse = True) == list(max_tokens_per_depth), 'max_tokens_per_depth must be in decreasing order'
|
||||
assert min(max_tokens_per_depth) > 0, 'max_tokens_per_depth must have at least 1 token at any layer'
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
b, n, device = *x.shape[:2], x.device
|
||||
|
||||
# use mask to keep track of the paddings when sampling tokens
|
||||
# as the duplicates (when sampling) are just removed, as mentioned in the paper
|
||||
mask = torch.ones((b, n), device = device, dtype = torch.bool)
|
||||
|
||||
token_ids = torch.arange(n, device = device)
|
||||
token_ids = repeat(token_ids, 'n -> b n', b = b)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
attn_out, mask, sampled_token_ids = attn(x, mask = mask)
|
||||
|
||||
# when token sampling, one needs to then gather the residual tokens with the sampled token ids
|
||||
if exists(sampled_token_ids):
|
||||
x = batched_index_select(x, sampled_token_ids, dim = 1)
|
||||
token_ids = batched_index_select(token_ids, sampled_token_ids, dim = 1)
|
||||
|
||||
x = x + attn_out
|
||||
|
||||
x = ff(x) + x
|
||||
|
||||
return x, token_ids
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, max_tokens_per_depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img, return_sampled_token_ids = False):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x, token_ids = self.transformer(x)
|
||||
|
||||
logits = self.mlp_head(x[:, 0])
|
||||
|
||||
if return_sampled_token_ids:
|
||||
# remove CLS token and decrement by 1 to make -1 the padding
|
||||
token_ids = token_ids[:, 1:] - 1
|
||||
return logits, token_ids
|
||||
|
||||
return logits
|
||||
@@ -6,18 +6,9 @@ import torch.nn.functional as F
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
def divisible_by(val, d):
|
||||
return (val % d) == 0
|
||||
|
||||
# cross embed layer
|
||||
|
||||
class CrossEmbedLayer(nn.Module):
|
||||
|
||||
229
vit_pytorch/mobile_vit.py
Normal file
229
vit_pytorch/mobile_vit.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
An implementation of MobileViT Model as defined in:
|
||||
MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer
|
||||
Arxiv: https://arxiv.org/abs/2110.02178
|
||||
Origin Code: https://github.com/murufeng/awesome_lightweight_networks
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
def Conv_BN_ReLU(inp, oup, kernel, stride=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
def conv_1x1_bn(inp, oup):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.ffn = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.ffn(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim=-1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
attn = self.attend(dots)
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b p h n d -> b p n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class MV2Block(nn.Module):
|
||||
def __init__(self, inp, oup, stride=1, expand_ratio=4):
|
||||
super(MV2Block, self).__init__()
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = round(inp * expand_ratio)
|
||||
self.identity = stride == 1 and inp == oup
|
||||
|
||||
if expand_ratio == 1:
|
||||
self.conv = nn.Sequential(
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Sequential(
|
||||
# pw
|
||||
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.identity:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
class MobileViTBlock(nn.Module):
|
||||
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.ph, self.pw = patch_size
|
||||
|
||||
self.conv1 = Conv_BN_ReLU(channel, channel, kernel_size)
|
||||
self.conv2 = conv_1x1_bn(channel, dim)
|
||||
|
||||
self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)
|
||||
|
||||
self.conv3 = conv_1x1_bn(dim, channel)
|
||||
self.conv4 = Conv_BN_ReLU(2 * channel, channel, kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
y = x.clone()
|
||||
|
||||
# Local representations
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
|
||||
# Global representations
|
||||
_, _, h, w = x.shape
|
||||
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
|
||||
x = self.transformer(x)
|
||||
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
|
||||
pw=self.pw)
|
||||
|
||||
# Fusion
|
||||
x = self.conv3(x)
|
||||
x = torch.cat((x, y), 1)
|
||||
x = self.conv4(x)
|
||||
return x
|
||||
|
||||
|
||||
class MobileViT(nn.Module):
|
||||
def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
|
||||
super().__init__()
|
||||
ih, iw = image_size
|
||||
ph, pw = patch_size
|
||||
assert ih % ph == 0 and iw % pw == 0
|
||||
|
||||
L = [2, 4, 3]
|
||||
|
||||
self.conv1 = Conv_BN_ReLU(3, channels[0], kernel=3, stride=2)
|
||||
|
||||
self.mv2 = nn.ModuleList([])
|
||||
self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
|
||||
self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
|
||||
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
|
||||
self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
|
||||
self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
|
||||
|
||||
self.mvit = nn.ModuleList([])
|
||||
self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
|
||||
self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
|
||||
self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))
|
||||
|
||||
self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
|
||||
|
||||
self.pool = nn.AvgPool2d(ih // 32, 1)
|
||||
self.fc = nn.Linear(channels[-1], num_classes, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.mv2[0](x)
|
||||
|
||||
x = self.mv2[1](x)
|
||||
x = self.mv2[2](x)
|
||||
x = self.mv2[3](x)
|
||||
|
||||
x = self.mv2[4](x)
|
||||
x = self.mvit[0](x)
|
||||
|
||||
x = self.mv2[5](x)
|
||||
x = self.mvit[1](x)
|
||||
|
||||
x = self.mv2[6](x)
|
||||
x = self.mvit[2](x)
|
||||
x = self.conv2(x)
|
||||
|
||||
x = self.pool(x).view(-1, x.shape[1])
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user