mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-05-13 03:29:58 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fd16cbdf2e |
@@ -24,7 +24,6 @@
|
||||
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
|
||||
- [Masked Patch Prediction](#masked-patch-prediction)
|
||||
- [Adaptive Token Sampling](#adaptive-token-sampling)
|
||||
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
|
||||
- [Dino](#dino)
|
||||
- [Accessing Attention](#accessing-attention)
|
||||
- [Research Ideas](#research-ideas)
|
||||
@@ -744,7 +743,7 @@ preds, token_ids = v(img, return_sampled_token_ids = True) # (1, 1000), (1, <=8)
|
||||
|
||||
<img src="./images/vit_for_small_datasets.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2112.13492">paper</a> proposes a new image to patch function that incorporates shifts of the image, before normalizing and dividing the image into patches. I have found shifting to be extremely helpful in some other transformers work, so decided to include this for further explorations. It also includes the `LSA` with the learned temperature and masking out of a token's attention to itself.
|
||||
This paper proposes a new image to patch function that incorporates shifts of the image, before normalizing and dividing the image into patches. I have found shifting to be extremely helpful in some other transformers work, so decided to include this for further explorations. It also includes the `LRA` with the learned temperature and masking out of token attention to itself.
|
||||
|
||||
You can use as follows:
|
||||
|
||||
|
||||
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.26.2',
|
||||
version = '0.26.0',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -14,13 +14,11 @@ class MAE(nn.Module):
|
||||
masking_ratio = 0.75,
|
||||
decoder_depth = 1,
|
||||
decoder_heads = 8,
|
||||
decoder_dim_head = 64,
|
||||
apply_decoder_pos_emb_all = False # whether to (re)apply decoder positional embedding to encoder unmasked tokens
|
||||
decoder_dim_head = 64
|
||||
):
|
||||
super().__init__()
|
||||
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
|
||||
self.masking_ratio = masking_ratio
|
||||
self.apply_decoder_pos_emb_all = apply_decoder_pos_emb_all
|
||||
|
||||
# extract some hyperparameters and functions from encoder (vision transformer to be trained)
|
||||
|
||||
@@ -73,11 +71,6 @@ class MAE(nn.Module):
|
||||
|
||||
decoder_tokens = self.enc_to_dec(encoded_tokens)
|
||||
|
||||
# reapply decoder position embedding to unmasked tokens, if desired
|
||||
|
||||
if self.apply_decoder_pos_emb_all:
|
||||
decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
|
||||
|
||||
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
|
||||
|
||||
mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
|
||||
|
||||
@@ -1,27 +1,40 @@
|
||||
"""
|
||||
An implementation of MobileViT Model as defined in:
|
||||
MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer
|
||||
Arxiv: https://arxiv.org/abs/2110.02178
|
||||
Origin Code: https://github.com/murufeng/awesome_lightweight_networks
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Reduce
|
||||
|
||||
# helpers
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
def conv_bn_relu(inp, oup, kernel, stride=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
def conv_1x1_bn(inp, oup):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
@@ -31,11 +44,10 @@ class PreNorm(nn.Module):
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
self.ffn = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -44,7 +56,8 @@ class FeedForward(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
return self.ffn(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
|
||||
@@ -63,8 +76,7 @@ class Attention(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(
|
||||
t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
attn = self.attend(dots)
|
||||
@@ -72,19 +84,15 @@ class Attention(nn.Module):
|
||||
out = rearrange(out, 'b p h n d -> b p n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
"""Transformer block described in ViT.
|
||||
Paper: https://arxiv.org/abs/2010.11929
|
||||
Based on: https://github.com/lucidrains/vit-pytorch
|
||||
"""
|
||||
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
@@ -94,24 +102,17 @@ class Transformer(nn.Module):
|
||||
return x
|
||||
|
||||
class MV2Block(nn.Module):
|
||||
"""MV2 block described in MobileNetV2.
|
||||
Paper: https://arxiv.org/pdf/1801.04381
|
||||
Based on: https://github.com/tonylins/pytorch-mobilenet-v2
|
||||
"""
|
||||
|
||||
def __init__(self, inp, oup, stride=1, expansion=4):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
def __init__(self, inp, oup, stride=1, expand_ratio=4):
|
||||
super(MV2Block, self).__init__()
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(inp * expansion)
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
hidden_dim = round(inp * expand_ratio)
|
||||
self.identity = stride == 1 and inp == oup
|
||||
|
||||
if expansion == 1:
|
||||
if expand_ratio == 1:
|
||||
self.conv = nn.Sequential(
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
|
||||
1, groups=hidden_dim, bias=False),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
@@ -125,8 +126,7 @@ class MV2Block(nn.Module):
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
|
||||
1, groups=hidden_dim, bias=False),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
@@ -136,7 +136,8 @@ class MV2Block(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(x)
|
||||
if self.use_res_connect:
|
||||
|
||||
if self.identity:
|
||||
out = out + x
|
||||
return out
|
||||
|
||||
@@ -145,13 +146,13 @@ class MobileViTBlock(nn.Module):
|
||||
super().__init__()
|
||||
self.ph, self.pw = patch_size
|
||||
|
||||
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
|
||||
self.conv1 = conv_bn_relu(channel, channel, kernel_size)
|
||||
self.conv2 = conv_1x1_bn(channel, dim)
|
||||
|
||||
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
|
||||
self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)
|
||||
|
||||
self.conv3 = conv_1x1_bn(dim, channel)
|
||||
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
|
||||
self.conv4 = conv_bn_relu(2 * channel, channel, kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
y = x.clone()
|
||||
@@ -162,11 +163,9 @@ class MobileViTBlock(nn.Module):
|
||||
|
||||
# Global representations
|
||||
_, _, h, w = x.shape
|
||||
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
|
||||
ph=self.ph, pw=self.pw)
|
||||
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
|
||||
x = self.transformer(x)
|
||||
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
|
||||
h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
|
||||
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)
|
||||
|
||||
# Fusion
|
||||
x = self.conv3(x)
|
||||
@@ -174,22 +173,18 @@ class MobileViTBlock(nn.Module):
|
||||
x = self.conv4(x)
|
||||
return x
|
||||
|
||||
class MobileViT(nn.Module):
|
||||
"""MobileViT.
|
||||
Paper: https://arxiv.org/abs/2110.02178
|
||||
Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
|
||||
"""
|
||||
|
||||
class MobileViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
dims,
|
||||
channels,
|
||||
num_classes,
|
||||
expansion=4,
|
||||
kernel_size=3,
|
||||
patch_size=(2, 2),
|
||||
depths=(2, 4, 3)
|
||||
expansion = 4,
|
||||
kernel_size = 3,
|
||||
patch_size = (2, 2),
|
||||
depths = (2, 4, 3)
|
||||
):
|
||||
super().__init__()
|
||||
assert len(dims) == 3, 'dims must be a tuple of 3'
|
||||
@@ -201,31 +196,28 @@ class MobileViT(nn.Module):
|
||||
|
||||
init_dim, *_, last_dim = channels
|
||||
|
||||
self.conv1 = conv_nxn_bn(3, init_dim, stride=2)
|
||||
self.conv1 = conv_bn_relu(3, init_dim, kernel=3, stride=2)
|
||||
|
||||
self.stem = nn.ModuleList([])
|
||||
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
|
||||
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
|
||||
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
|
||||
|
||||
self.trunk = nn.ModuleList([])
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[3], channels[4], 2, expansion),
|
||||
MobileViTBlock(dims[0], depths[0], channels[5],
|
||||
kernel_size, patch_size, int(dims[0] * 2))
|
||||
MobileViTBlock(dims[0], depths[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))
|
||||
]))
|
||||
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[5], channels[6], 2, expansion),
|
||||
MobileViTBlock(dims[1], depths[1], channels[7],
|
||||
kernel_size, patch_size, int(dims[1] * 4))
|
||||
MobileViTBlock(dims[1], depths[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))
|
||||
]))
|
||||
|
||||
self.trunk.append(nn.ModuleList([
|
||||
MV2Block(channels[7], channels[8], 2, expansion),
|
||||
MobileViTBlock(dims[2], depths[2], channels[9],
|
||||
kernel_size, patch_size, int(dims[2] * 4))
|
||||
MobileViTBlock(dims[2], depths[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))
|
||||
]))
|
||||
|
||||
self.to_logits = nn.Sequential(
|
||||
|
||||
Reference in New Issue
Block a user