Compare commits

..

3 Commits

Author SHA1 Message Date
Phil Wang
563f23886d add some tests 2021-12-22 09:12:44 -08:00
Phil Wang
2c368d1d4e add extractor wrapper 2021-12-21 11:11:39 -08:00
Phil Wang
b983bbee39 release MobileViT, from @murufeng 2021-12-21 10:22:59 -08:00
5 changed files with 194 additions and 51 deletions

33
.github/workflows/python-test.yml vendored Normal file
View File

@@ -0,0 +1,33 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Test
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
python setup.py test

View File

@@ -554,17 +554,17 @@ pred = nest(img) # (1, 1000)
<img src="./images/mbvit.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2110.02178">paper</a> introduce MobileViT, a light-weight and generalpurpose vision transformer for mobile devices. MobileViT presents a different
This <a href="https://arxiv.org/abs/2110.02178">paper</a> introduce MobileViT, a light-weight and general purpose vision transformer for mobile devices. MobileViT presents a different
perspective for the global processing of information with transformers.
You can use it with the following code (ex. mobilevit_xs)
```
```python
import torch
from vit_pytorch.mobile_vit import MobileViT
mbvit_xs = MobileViT(
image_size=(256, 256),
image_size = (256, 256),
dims = [96, 120, 144],
channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
num_classes = 1000
@@ -834,6 +834,41 @@ to cleanup the class and the hooks once you have collected enough data
v = v.eject() # wrapper is discarded and original ViT instance is returned
```
## Accessing Embeddings
You can similarly access the embeddings with the `Extractor` wrapper
```python
import torch
from vit_pytorch.vit import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
# import Recorder and wrap the ViT
from vit_pytorch.extractor import Extractor
v = Extractor(v)
# forward pass now returns predictions and the attention maps
img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)
# there is one extra token due to the CLS token
embeddings # (1, 65, 1024) - (batch x patches x model dim)
```
## Research Ideas
### Efficient Attention
@@ -1190,6 +1225,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{mehta2021mobilevit,
title = {MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
author = {Sachin Mehta and Mohammad Rastegari},
year = {2021},
eprint = {2110.02178},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.25.0',
version = '0.25.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
@@ -19,6 +19,12 @@ setup(
'torch>=1.6',
'torchvision'
],
setup_requires=[
'pytest-runner',
],
tests_require=[
'pytest'
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',

48
vit_pytorch/extractor.py Normal file
View File

@@ -0,0 +1,48 @@
import torch
from torch import nn
def exists(val):
return val is not None
class Extractor(nn.Module):
def __init__(self, vit, device = None):
super().__init__()
self.vit = vit
self.data = None
self.latents = None
self.hooks = []
self.hook_registered = False
self.ejected = False
self.device = device
def _hook(self, _, input, output):
self.latents = output.clone().detach()
def _register_hook(self):
handle = self.vit.transformer.register_forward_hook(self._hook)
self.hooks.append(handle)
self.hook_registered = True
def eject(self):
self.ejected = True
for hook in self.hooks:
hook.remove()
self.hooks.clear()
return self.vit
def clear(self):
del self.latents
self.latents = None
def forward(self, img):
assert not self.ejected, 'extractor has been ejected, cannot be used anymore'
self.clear()
if not self.hook_registered:
self._register_hook()
pred = self.vit(img)
target_device = self.device if exists(self.device) else img.device
latents = self.latents.to(target_device)
return pred, latents

View File

@@ -9,9 +9,9 @@ import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Reduce
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
@@ -20,7 +20,7 @@ def _make_divisible(v, divisor, min_value=None):
return new_v
def Conv_BN_ReLU(inp, oup, kernel, stride=1):
def conv_bn_relu(inp, oup, kernel, stride=1):
return nn.Sequential(
nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(oup),
@@ -63,8 +63,6 @@ class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
@@ -74,7 +72,7 @@ class Attention(nn.Module):
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
)
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim=-1)
@@ -96,6 +94,7 @@ class Transformer(nn.Module):
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
@@ -136,23 +135,24 @@ class MV2Block(nn.Module):
)
def forward(self, x):
out = self.conv(x)
if self.identity:
return x + self.conv(x)
else:
return self.conv(x)
out = out + x
return out
class MobileViTBlock(nn.Module):
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
super().__init__()
self.ph, self.pw = patch_size
self.conv1 = Conv_BN_ReLU(channel, channel, kernel_size)
self.conv1 = conv_bn_relu(channel, channel, kernel_size)
self.conv2 = conv_1x1_bn(channel, dim)
self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)
self.conv3 = conv_1x1_bn(dim, channel)
self.conv4 = Conv_BN_ReLU(2 * channel, channel, kernel_size)
self.conv4 = conv_bn_relu(2 * channel, channel, kernel_size)
def forward(self, x):
y = x.clone()
@@ -165,8 +165,7 @@ class MobileViTBlock(nn.Module):
_, _, h, w = x.shape
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
pw=self.pw)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)
# Fusion
x = self.conv3(x)
@@ -176,54 +175,65 @@ class MobileViTBlock(nn.Module):
class MobileViT(nn.Module):
def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
def __init__(
self,
image_size,
dims,
channels,
num_classes,
expansion = 4,
kernel_size = 3,
patch_size = (2, 2),
depths = (2, 4, 3)
):
super().__init__()
assert len(dims) == 3, 'dims must be a tuple of 3'
assert len(depths) == 3, 'depths must be a tuple of 3'
ih, iw = image_size
ph, pw = patch_size
assert ih % ph == 0 and iw % pw == 0
L = [2, 4, 3]
init_dim, *_, last_dim = channels
self.conv1 = Conv_BN_ReLU(3, channels[0], kernel=3, stride=2)
self.conv1 = conv_bn_relu(3, init_dim, kernel=3, stride=2)
self.mv2 = nn.ModuleList([])
self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
self.stem = nn.ModuleList([])
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.trunk = nn.ModuleList([])
self.trunk.append(nn.ModuleList([
MV2Block(channels[3], channels[4], 2, expansion),
MobileViTBlock(dims[0], depths[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))
]))
self.mvit = nn.ModuleList([])
self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))
self.trunk.append(nn.ModuleList([
MV2Block(channels[5], channels[6], 2, expansion),
MobileViTBlock(dims[1], depths[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))
]))
self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
self.trunk.append(nn.ModuleList([
MV2Block(channels[7], channels[8], 2, expansion),
MobileViTBlock(dims[2], depths[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))
]))
self.pool = nn.AvgPool2d(ih // 32, 1)
self.fc = nn.Linear(channels[-1], num_classes, bias=False)
self.to_logits = nn.Sequential(
conv_1x1_bn(channels[-2], last_dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(channels[-1], num_classes, bias=False)
)
def forward(self, x):
x = self.conv1(x)
x = self.mv2[0](x)
x = self.mv2[1](x)
x = self.mv2[2](x)
x = self.mv2[3](x)
for conv in self.stem:
x = conv(x)
x = self.mv2[4](x)
x = self.mvit[0](x)
x = self.mv2[5](x)
x = self.mvit[1](x)
x = self.mv2[6](x)
x = self.mvit[2](x)
x = self.conv2(x)
x = self.pool(x).view(-1, x.shape[1])
x = self.fc(x)
return x
for conv, attn in self.trunk:
x = conv(x)
x = attn(x)
return self.to_logits(x)