mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
add CvT
This commit is contained in:
11
README.md
11
README.md
@@ -443,6 +443,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{wu2021cvt,
|
||||||
|
title = {CvT: Introducing Convolutions to Vision Transformers},
|
||||||
|
author = {Haiping Wu and Bin Xiao and Noel Codella and Mengchen Liu and Xiyang Dai and Lu Yuan and Lei Zhang},
|
||||||
|
year = {2021},
|
||||||
|
eprint = {2103.15808},
|
||||||
|
archivePrefix = {arXiv},
|
||||||
|
primaryClass = {cs.CV}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@misc{vaswani2017attention,
|
@misc{vaswani2017attention,
|
||||||
title = {Attention Is All You Need},
|
title = {Attention Is All You Need},
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
|||||||
setup(
|
setup(
|
||||||
name = 'vit-pytorch',
|
name = 'vit-pytorch',
|
||||||
packages = find_packages(exclude=['examples']),
|
packages = find_packages(exclude=['examples']),
|
||||||
version = '0.10.3',
|
version = '0.11.0',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'Vision Transformer (ViT) - Pytorch',
|
description = 'Vision Transformer (ViT) - Pytorch',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
|
|||||||
151
vit_pytorch/cvt.py
Normal file
151
vit_pytorch/cvt.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn, einsum
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
|
|
||||||
|
# helper methods
|
||||||
|
|
||||||
|
def group_dict_by_key(cond, d):
|
||||||
|
return_val = [dict(), dict()]
|
||||||
|
for key in d.keys():
|
||||||
|
match = bool(cond(key))
|
||||||
|
ind = int(not match)
|
||||||
|
return_val[ind][key] = d[key]
|
||||||
|
return (*return_val,)
|
||||||
|
|
||||||
|
def group_by_key_prefix_and_remove_prefix(prefix, d):
|
||||||
|
kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
|
||||||
|
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||||
|
return kwargs_without_prefix, kwargs
|
||||||
|
|
||||||
|
# classes
|
||||||
|
|
||||||
|
class PreNorm(nn.Module):
|
||||||
|
def __init__(self, dim, fn):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
self.fn = fn
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
x = rearrange(x, 'b c h w -> b h w c')
|
||||||
|
x = self.norm(x)
|
||||||
|
x = rearrange(x, 'b h w c -> b c h w')
|
||||||
|
return self.fn(x, **kwargs)
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, mult = 4, dropout = 0.):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Conv2d(dim, dim * mult, 1),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Conv2d(dim * mult, dim, 1),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
padding = proj_kernel // 2
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = dim_head ** -0.5
|
||||||
|
|
||||||
|
self.attend = nn.Softmax(dim = -1)
|
||||||
|
|
||||||
|
self.to_q = nn.Conv2d(dim, inner_dim, 3, padding = padding, stride = 1, bias = False)
|
||||||
|
self.to_kv = nn.Conv2d(dim, inner_dim * 2, 3, padding = padding, stride = kv_proj_stride, bias = False)
|
||||||
|
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
nn.Conv2d(inner_dim, dim, 1),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
shape = x.shape
|
||||||
|
b, n, _, y, h = *shape, self.heads
|
||||||
|
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
|
||||||
|
|
||||||
|
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
|
||||||
|
attn = self.attend(dots)
|
||||||
|
|
||||||
|
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||||
|
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
for _ in range(depth):
|
||||||
|
self.layers.append(nn.ModuleList([
|
||||||
|
PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||||
|
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
|
||||||
|
]))
|
||||||
|
def forward(self, x):
|
||||||
|
for attn, ff in self.layers:
|
||||||
|
x = attn(x) + x
|
||||||
|
x = ff(x) + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
class CvT(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
num_classes,
|
||||||
|
s1_emb_dim = 64,
|
||||||
|
s1_emb_kernel = 7,
|
||||||
|
s1_emb_stride = 4,
|
||||||
|
s1_proj_kernel = 3,
|
||||||
|
s1_kv_proj_stride = 2,
|
||||||
|
s1_heads = 1,
|
||||||
|
s1_depth = 1,
|
||||||
|
s1_mlp_mult = 4,
|
||||||
|
s2_emb_dim = 192,
|
||||||
|
s2_emb_kernel = 3,
|
||||||
|
s2_emb_stride = 2,
|
||||||
|
s2_proj_kernel = 3,
|
||||||
|
s2_kv_proj_stride = 2,
|
||||||
|
s2_heads = 3,
|
||||||
|
s2_depth = 2,
|
||||||
|
s2_mlp_mult = 4,
|
||||||
|
s3_emb_dim = 384,
|
||||||
|
s3_emb_kernel = 3,
|
||||||
|
s3_emb_stride = 2,
|
||||||
|
s3_proj_kernel = 3,
|
||||||
|
s3_kv_proj_stride = 2,
|
||||||
|
s3_heads = 4,
|
||||||
|
s3_depth = 10,
|
||||||
|
s3_mlp_mult = 4,
|
||||||
|
dropout = 0.
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
kwargs = dict(locals())
|
||||||
|
|
||||||
|
dim = 3
|
||||||
|
layers = []
|
||||||
|
|
||||||
|
for prefix in ('s1', 's2', 's3'):
|
||||||
|
config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
|
||||||
|
|
||||||
|
layers.append(nn.Sequential(
|
||||||
|
nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
|
||||||
|
Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)
|
||||||
|
))
|
||||||
|
|
||||||
|
dim = config['emb_dim']
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(
|
||||||
|
*layers,
|
||||||
|
nn.AdaptiveAvgPool2d(1),
|
||||||
|
Rearrange('... () () -> ...'),
|
||||||
|
nn.Linear(dim, num_classes)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
Reference in New Issue
Block a user