release MobileViT, from @murufeng

This commit is contained in:
Phil Wang
2021-12-21 10:22:59 -08:00
parent 86a7302ba6
commit b983bbee39
3 changed files with 72 additions and 51 deletions

View File

@@ -554,17 +554,17 @@ pred = nest(img) # (1, 1000)
<img src="./images/mbvit.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2110.02178">paper</a> introduce MobileViT, a light-weight and generalpurpose vision transformer for mobile devices. MobileViT presents a different
This <a href="https://arxiv.org/abs/2110.02178">paper</a> introduce MobileViT, a light-weight and general purpose vision transformer for mobile devices. MobileViT presents a different
perspective for the global processing of information with transformers.
You can use it with the following code (ex. mobilevit_xs)
```
```python
import torch
from vit_pytorch.mobile_vit import MobileViT
mbvit_xs = MobileViT(
image_size=(256, 256),
image_size = (256, 256),
dims = [96, 120, 144],
channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
num_classes = 1000
@@ -1190,6 +1190,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{mehta2021mobilevit,
title = {MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
author = {Sachin Mehta and Mohammad Rastegari},
year = {2021},
eprint = {2110.02178},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.24.3',
version = '0.25.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -9,9 +9,9 @@ import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Reduce
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
@@ -20,7 +20,7 @@ def _make_divisible(v, divisor, min_value=None):
return new_v
def Conv_BN_ReLU(inp, oup, kernel, stride=1):
def conv_bn_relu(inp, oup, kernel, stride=1):
return nn.Sequential(
nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(oup),
@@ -63,8 +63,6 @@ class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
@@ -74,7 +72,7 @@ class Attention(nn.Module):
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
)
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim=-1)
@@ -96,6 +94,7 @@ class Transformer(nn.Module):
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
@@ -136,23 +135,24 @@ class MV2Block(nn.Module):
)
def forward(self, x):
out = self.conv(x)
if self.identity:
return x + self.conv(x)
else:
return self.conv(x)
out = out + x
return out
class MobileViTBlock(nn.Module):
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
super().__init__()
self.ph, self.pw = patch_size
self.conv1 = Conv_BN_ReLU(channel, channel, kernel_size)
self.conv1 = conv_bn_relu(channel, channel, kernel_size)
self.conv2 = conv_1x1_bn(channel, dim)
self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)
self.conv3 = conv_1x1_bn(dim, channel)
self.conv4 = Conv_BN_ReLU(2 * channel, channel, kernel_size)
self.conv4 = conv_bn_relu(2 * channel, channel, kernel_size)
def forward(self, x):
y = x.clone()
@@ -165,8 +165,7 @@ class MobileViTBlock(nn.Module):
_, _, h, w = x.shape
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
pw=self.pw)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)
# Fusion
x = self.conv3(x)
@@ -176,54 +175,65 @@ class MobileViTBlock(nn.Module):
class MobileViT(nn.Module):
def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
def __init__(
self,
image_size,
dims,
channels,
num_classes,
expansion = 4,
kernel_size = 3,
patch_size = (2, 2),
depths = (2, 4, 3)
):
super().__init__()
assert len(dims) == 3, 'dims must be a tuple of 3'
assert len(depths) == 3, 'depths must be a tuple of 3'
ih, iw = image_size
ph, pw = patch_size
assert ih % ph == 0 and iw % pw == 0
L = [2, 4, 3]
init_dim, *_, last_dim = channels
self.conv1 = Conv_BN_ReLU(3, channels[0], kernel=3, stride=2)
self.conv1 = conv_bn_relu(3, init_dim, kernel=3, stride=2)
self.mv2 = nn.ModuleList([])
self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
self.stem = nn.ModuleList([])
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.trunk = nn.ModuleList([])
self.trunk.append(nn.ModuleList([
MV2Block(channels[3], channels[4], 2, expansion),
MobileViTBlock(dims[0], depths[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))
]))
self.mvit = nn.ModuleList([])
self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))
self.trunk.append(nn.ModuleList([
MV2Block(channels[5], channels[6], 2, expansion),
MobileViTBlock(dims[1], depths[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))
]))
self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
self.trunk.append(nn.ModuleList([
MV2Block(channels[7], channels[8], 2, expansion),
MobileViTBlock(dims[2], depths[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))
]))
self.pool = nn.AvgPool2d(ih // 32, 1)
self.fc = nn.Linear(channels[-1], num_classes, bias=False)
self.to_logits = nn.Sequential(
conv_1x1_bn(channels[-2], last_dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(channels[-1], num_classes, bias=False)
)
def forward(self, x):
x = self.conv1(x)
x = self.mv2[0](x)
x = self.mv2[1](x)
x = self.mv2[2](x)
x = self.mv2[3](x)
for conv in self.stem:
x = conv(x)
x = self.mv2[4](x)
x = self.mvit[0](x)
x = self.mv2[5](x)
x = self.mvit[1](x)
x = self.mv2[6](x)
x = self.mvit[2](x)
x = self.conv2(x)
x = self.pool(x).view(-1, x.shape[1])
x = self.fc(x)
return x
for conv, attn in self.trunk:
x = conv(x)
x = attn(x)
return self.to_logits(x)