Mirror of https://github.com/lucidrains/vit-pytorch.git (synced 2025-12-30 08:02:29 +00:00)
Compare commits
8 Commits
| Author | SHA1 | Date |
|---|---|---|
| | db04c0f319 | |
| | 2cb6b35030 | |
| | 2ec9161a98 | |
| | 3a3038c702 | |
| | b1f1044c8e | |
| | deb96201d5 | |
| | 05b47cc070 | |
| | 9ef8da4759 | |
README.md (98 changed lines)
@@ -93,7 +93,8 @@ distiller = DistillWrapper(
     student = v,
     teacher = teacher,
     temperature = 3,    # temperature of distillation
-    alpha = 0.5         # trade between main loss and distillation loss
+    alpha = 0.5,        # trade between main loss and distillation loss
+    hard = False        # whether to use soft or hard distillation
 )
 
 img = torch.randn(2, 3, 256, 256)
@@ -143,6 +144,37 @@ img = torch.randn(1, 3, 256, 256)
 preds = v(img) # (1, 1000)
 ```
 
+## CaiT
+
+<a href="https://arxiv.org/abs/2103.17239">This paper</a> also notes the difficulty in training vision transformers at greater depths and proposes two solutions. First, it proposes per-channel multiplication of the output of the residual block. Second, it proposes to have the patches attend only to one another, and to allow the CLS token to attend to the patches only in the last few layers.
+
+They also add <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a>, noting improvements.
+
+You can use this scheme as follows
+
+```python
+import torch
+from vit_pytorch.cait import CaiT
+
+v = CaiT(
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    dim = 1024,
+    depth = 12,            # depth of transformer for patch to patch attention only
+    cls_depth = 2,         # depth of cross attention of CLS tokens to patch
+    heads = 16,
+    mlp_dim = 2048,
+    dropout = 0.1,
+    emb_dropout = 0.1,
+    layer_dropout = 0.05   # randomly dropout 5% of the layers
+)
+
+img = torch.randn(1, 3, 256, 256)
+
+preds = v(img) # (1, 1000)
+```
+
 ## Token-to-Token ViT
 
 <img src="./images/t2t.png" width="400px"></img>
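For orientation, a minimal sketch of the per-channel residual scaling described above. It mirrors the LayerScale module added in vit_pytorch/cait.py further down in this diff; the `init_eps` value and the `nn.Linear` branch used here are illustrative, not the committed code.

```python
import torch
from torch import nn

class LayerScale(nn.Module):
    # scale each channel of the residual branch's output by a small learned factor
    def __init__(self, dim, fn, init_eps = 1e-5):  # init_eps is illustrative
        super().__init__()
        self.scale = nn.Parameter(torch.full((1, 1, dim), init_eps))
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

block = LayerScale(dim = 1024, fn = nn.Linear(1024, 1024))
x = torch.randn(1, 64, 1024)
out = block(x) + x   # the scaled branch is added back as a residual
```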
@@ -164,7 +196,8 @@ v = T2TViT(
 )
 
 img = torch.randn(1, 3, 224, 224)
-v(img) # (1, 1000)
+
+preds = v(img) # (1, 1000)
 ```
 
 ## Cross ViT
@@ -177,7 +210,7 @@ v(img) # (1, 1000)
 import torch
 from vit_pytorch.cross_vit import CrossViT
 
-model = CrossViT(
+v = CrossViT(
     image_size = 256,
     num_classes = 1000,
     depth = 4,    # number of multi-scale encoding blocks
@@ -199,7 +232,7 @@ model = CrossViT(
 
 img = torch.randn(1, 3, 256, 256)
 
-pred = model(img) # (1, 1000)
+pred = v(img) # (1, 1000)
 ```
 
 ## PiT
@@ -212,7 +245,7 @@ pred = model(img) # (1, 1000)
 import torch
 from vit_pytorch.pit import PiT
 
-p = PiT(
+v = PiT(
     image_size = 224,
     patch_size = 14,
     dim = 256,
@@ -228,7 +261,34 @@ p = PiT(
 
 img = torch.randn(1, 3, 224, 224)
 
-preds = p(img) # (1, 1000)
+preds = v(img) # (1, 1000)
 ```
 
+## LeViT
+
+<img src="./images/levit.png" width="300px"></img>
+
+<a href="https://arxiv.org/abs/2104.01136">This paper</a> proposes a number of changes, including (1) convolutional embedding instead of patch-wise projection, (2) downsampling in stages, (3) an extra non-linearity in attention, (4) 2d relative positional biases instead of an initial absolute positional bias, and (5) batchnorm in place of layernorm.
+
+```python
+import torch
+from vit_pytorch.levit import LeViT
+
+levit = LeViT(
+    image_size = 224,
+    num_classes = 1000,
+    stages = 3,             # number of stages
+    dim = (256, 384, 512),  # dimensions at each stage
+    depth = 4,
+    heads = (4, 6, 8),      # heads at each stage
+    mlp_mult = 2,
+    dropout = 0.1,
+    emb_dropout = 0.1
+)
+
+img = torch.randn(1, 3, 224, 224)
+
+levit(img) # (1, 1000)
+```
+
 ## CvT
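A small sketch of points (1) and (2) above, mirroring the convolutional stem and stage transitions in vit_pytorch/levit.py further down in this diff (channel widths follow that file, with dims[0] = 256 as in the example above):

```python
import torch
from torch import nn
from math import ceil

# four stride-2 convs: the "convolutional embedding" that replaces patch projection
conv_stem = nn.Sequential(
    nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
    nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
    nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
    nn.Conv2d(128, 256, 3, stride = 2, padding = 1)
)

fmap = conv_stem(torch.randn(1, 3, 224, 224))
print(fmap.shape)   # torch.Size([1, 256, 14, 14]), i.e. image_size // 16

# each downsampling attention block between stages halves the feature map again
fmap_size = 14
for _ in range(2):                  # two transitions for three stages
    fmap_size = ceil(fmap_size / 2)
print(fmap_size)                    # 4
```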
@@ -241,7 +301,7 @@ preds = p(img) # (1, 1000)
 import torch
 from vit_pytorch.cvt import CvT
 
-model = CvT(
+v = CvT(
     num_classes = 1000,
     s1_emb_dim = 64,        # stage 1 - dimension
     s1_emb_kernel = 7,      # stage 1 - conv kernel
@@ -272,7 +332,7 @@ model = CvT(
 
 img = torch.randn(1, 3, 224, 224)
 
-pred = model(img) # (1, 1000)
+pred = v(img) # (1, 1000)
 ```
 
 ## Masked Patch Prediction
@@ -540,6 +600,17 @@ Coming from computer vision and new to transformers? Here are some resources that
 }
 ```
 
+```bibtex
+@misc{touvron2021going,
+    title = {Going deeper with Image Transformers},
+    author = {Hugo Touvron and Matthieu Cord and Alexandre Sablayrolles and Gabriel Synnaeve and Hervé Jégou},
+    year = {2021},
+    eprint = {2103.17239},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
+
 ```bibtex
 @misc{chen2021crossvit,
     title = {CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification},
@@ -573,6 +644,17 @@ Coming from computer vision and new to transformers? Here are some resources that
 }
 ```
 
+```bibtex
+@misc{graham2021levit,
+    title = {LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
+    author = {Ben Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Hervé Jégou and Matthijs Douze},
+    year = {2021},
+    eprint = {2104.01136},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
+
 ```bibtex
 @misc{vaswani2017attention,
     title = {Attention Is All You Need},
BIN  images/cait.png (new file, binary not shown, 63 KiB)
BIN  images/levit.png (new file, binary not shown, 71 KiB)
setup.py (2 changed lines)
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.12.0',
+  version = '0.15.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
vit_pytorch/cait.py (new file, 177 lines)
@@ -0,0 +1,177 @@
from random import randrange
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def dropout_layers(layers, dropout):
    if dropout == 0:
        return layers

    num_layers = len(layers)
    to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout

    # make sure at least one layer makes it
    if all(to_drop):
        rand_index = randrange(num_layers)
        to_drop[rand_index] = False

    layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
    return layers

# classes

class LayerScale(nn.Module):
    def __init__(self, dim, fn, depth):
        super().__init__()
        if depth <= 18:  # epsilon detailed in section 2 of paper
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.attend = nn.Softmax(dim = -1)

        self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
        self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None):
        b, n, _, h = *x.shape, self.heads

        context = x if not exists(context) else torch.cat((x, context), dim = 1)

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn)   # talking heads, pre-softmax
        attn = self.attend(dots)
        attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn)  # talking heads, post-softmax

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            self.layers.append(nn.ModuleList([
                LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
                LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
            ]))
    def forward(self, x, context = None):
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for attn, ff in layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x

class CaiT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        cls_depth,
        heads,
        mlp_dim,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        layer_dropout = 0.
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.patch_transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
        self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        x += self.pos_embedding[:, :n]
        x = self.dropout(x)

        x = self.patch_transformer(x)

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = self.cls_transformer(cls_tokens, context = x)

        return self.mlp_head(x[:, 0])
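A quick, hedged check of the layer-dropout helper above (assumes this file is importable as vit_pytorch.cait, as the README example suggests; the block names are placeholders):

```python
from vit_pytorch.cait import dropout_layers

blocks = ['block_0', 'block_1', 'block_2', 'block_3']
kept = dropout_layers(blocks, dropout = 0.5)
print(kept)   # a random subset of the blocks; at least one always survives
```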
vit_pytorch/distill.py
@@ -104,7 +104,8 @@ class DistillWrapper(nn.Module):
         teacher,
         student,
         temperature = 1.,
-        alpha = 0.5
+        alpha = 0.5,
+        hard = False
     ):
         super().__init__()
         assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
@@ -116,6 +117,7 @@ class DistillWrapper(nn.Module):
         num_classes = student.num_classes
         self.temperature = temperature
         self.alpha = alpha
+        self.hard = hard
 
         self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
 
@@ -137,11 +139,15 @@ class DistillWrapper(nn.Module):
 
         loss = F.cross_entropy(student_logits, labels)
 
-        distill_loss = F.kl_div(
-            F.log_softmax(distill_logits / T, dim = -1),
-            F.softmax(teacher_logits / T, dim = -1).detach(),
-            reduction = 'batchmean')
+        if not self.hard:
+            distill_loss = F.kl_div(
+                F.log_softmax(distill_logits / T, dim = -1),
+                F.softmax(teacher_logits / T, dim = -1).detach(),
+                reduction = 'batchmean')
+            distill_loss *= T ** 2
 
-        distill_loss *= T ** 2
+        else:
+            teacher_labels = teacher_logits.argmax(dim = -1)
+            distill_loss = F.cross_entropy(student_logits, teacher_labels)
 
         return loss * alpha + distill_loss * (1 - alpha)
 
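A hedged usage sketch of the new `hard` flag, reusing the student `v` and `teacher` from the README distillation example above and assuming the wrapper is called as `distiller(img, labels)`, as in that example; the label tensor here is illustrative:

```python
distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,
    alpha = 0.5,
    hard = True   # cross entropy against the teacher's argmax labels instead of soft targets
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))   # illustrative labels

loss = distiller(img, labels)
loss.backward()
```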
vit_pytorch/levit.py (new file, 192 lines)
@@ -0,0 +1,192 @@
from math import ceil

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_tuple(val, l = 3):
    val = val if isinstance(val, tuple) else (val,)
    return (*val, *((val[-1],) * max(l - len(val), 0)))

def always(val):
    return lambda *args, **kwargs: val

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, mult, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mult, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        inner_dim_key = dim_key * heads
        inner_dim_value = dim_value * heads
        dim_out = default(dim_out, dim)

        self.heads = heads
        self.scale = dim_key ** -0.5

        self.to_q = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, stride = (2 if downsample else 1), bias = False), nn.BatchNorm2d(inner_dim_key))
        self.to_k = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, bias = False), nn.BatchNorm2d(inner_dim_key))
        self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))

        self.attend = nn.Softmax(dim = -1)

        self.to_out = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(inner_dim_value, dim_out, 1),
            nn.BatchNorm2d(dim_out),
            nn.Dropout(dropout)
        )

        # positional bias

        self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)

        q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
        k_range = torch.arange(fmap_size)

        q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
        k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)

        q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
        rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()

        x_rel, y_rel = rel_pos.unbind(dim = -1)
        pos_indices = (x_rel * fmap_size) + y_rel

        self.register_buffer('pos_indices', pos_indices)

    def apply_pos_bias(self, fmap):
        # add the learned 2d relative positional bias to the attention logits
        bias = self.pos_bias(self.pos_indices)
        bias = rearrange(bias, 'i j h -> () h i j')
        return fmap + bias

    def forward(self, x):
        b, n, *_, h = *x.shape, self.heads

        q = self.to_q(x)
        y = q.shape[2]

        qkv = (q, self.to_k(x), self.to_v(x))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        dots = self.apply_pos_bias(dots)

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.layers = nn.ModuleList([])
        self.attn_residual = (not downsample) and dim == dim_out

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out),
                FeedForward(dim_out, mlp_mult, dropout = dropout)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            attn_res = (x if self.attn_residual else 0)
            x = attn(x) + attn_res
            x = ff(x) + x
        return x

class LeViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_mult,
        stages = 3,
        dim_key = 32,
        dim_value = 64,
        dropout = 0.,
        emb_dropout = 0.,   # note: currently unused in this module
        num_distill_classes = None
    ):
        super().__init__()

        dims = cast_tuple(dim, stages)
        depths = cast_tuple(depth, stages)
        layer_heads = cast_tuple(heads, stages)

        assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a single value or a tuple of length at most the designated number of stages'

        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
            nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
            nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
            nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1)
        )

        fmap_size = image_size // (2 ** 4)
        layers = []

        for ind, dim, depth, heads in zip(range(stages), dims, depths, layer_heads):
            is_last = ind == (stages - 1)
            layers.append(Transformer(dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult, dropout))

            if not is_last:
                next_dim = dims[ind + 1]
                layers.append(Transformer(dim, fmap_size, 1, heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
                fmap_size = ceil(fmap_size / 2)

        self.backbone = nn.Sequential(*layers)

        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...')
        )

        self.distill_head = nn.Linear(dim, num_distill_classes) if exists(num_distill_classes) else always(None)
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        x = self.to_patch_embedding(img)

        x = self.backbone(x)

        x = self.pool(x)

        out = self.mlp_head(x)
        distill = self.distill_head(x)

        if exists(distill):
            return out, distill

        return out
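A hedged sketch of the optional distillation head in the file above: when num_distill_classes is set, the forward pass returns a pair of logits. The hyperparameters below simply reuse the values from the README example; the distillation class count is illustrative.

```python
import torch
from vit_pytorch.levit import LeViT

levit = LeViT(
    image_size = 224,
    num_classes = 1000,
    stages = 3,
    dim = (256, 384, 512),
    depth = 4,
    heads = (4, 6, 8),
    mlp_mult = 2,
    num_distill_classes = 1000   # adds a second linear head on the pooled features
)

img = torch.randn(1, 3, 224, 224)

out, distill = levit(img)   # (1, 1000) classification logits, (1, 1000) distillation logits
```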