mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3df6c31c61 | ||
|
|
54af220930 | ||
|
|
bad4b94e7b | ||
|
|
fbced01fe7 | ||
|
|
e42e9876bc | ||
|
|
566365978d | ||
|
|
34f78294d3 | ||
|
|
4c29328363 | ||
|
|
27ac10c1f1 | ||
|
|
fa216c45ea | ||
|
|
1d8b7826bf | ||
|
|
53b3af05f6 | ||
|
|
6289619e3f | ||
|
|
b42fa7862e | ||
|
|
dc6622c05c | ||
|
|
30b37c4028 | ||
|
|
4497f1e90f | ||
|
|
b50d3e1334 | ||
|
|
e075460937 | ||
|
|
5e23e48e4d | ||
|
|
db04c0f319 |
27
README.md
27
README.md
@@ -279,11 +279,10 @@ levit = LeViT(
|
||||
num_classes = 1000,
|
||||
stages = 3, # number of stages
|
||||
dim = (256, 384, 512), # dimensions at each stage
|
||||
depth = 4,
|
||||
depth = 4, # transformer of depth 4 at each stage
|
||||
heads = (4, 6, 8), # heads at each stage
|
||||
mlp_mult = 2,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
@@ -655,6 +654,28 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{li2021localvit,
|
||||
title = {LocalViT: Bringing Locality to Vision Transformers},
|
||||
author = {Yawei Li and Kai Zhang and Jiezhang Cao and Radu Timofte and Luc Van Gool},
|
||||
year = {2021},
|
||||
eprint = {2104.05707},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{su2021roformer,
|
||||
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
|
||||
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
|
||||
year = {2021},
|
||||
eprint = {2104.09864},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CL}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
|
||||
BIN
images/levit.png
Normal file
BIN
images/levit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 71 KiB |
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.15.0',
|
||||
version = '0.16.11',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -22,15 +22,25 @@ def group_by_key_prefix_and_remove_prefix(prefix, d):
|
||||
|
||||
# classes
|
||||
|
||||
class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (std + self.eps) * self.g + self.b
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.norm = LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
x = rearrange(x, 'b c h w -> b h w c')
|
||||
x = self.norm(x)
|
||||
x = rearrange(x, 'b h w c -> b c h w')
|
||||
return self.fn(x, **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
@@ -67,8 +77,8 @@ class Attention(nn.Module):
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_q = DepthWiseConv2d(dim, inner_dim, 3, padding = padding, stride = 1, bias = False)
|
||||
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = padding, stride = kv_proj_stride, bias = False)
|
||||
self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
|
||||
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1),
|
||||
@@ -130,7 +140,7 @@ class CvT(nn.Module):
|
||||
s3_emb_stride = 2,
|
||||
s3_proj_kernel = 3,
|
||||
s3_kv_proj_stride = 2,
|
||||
s3_heads = 4,
|
||||
s3_heads = 6,
|
||||
s3_depth = 10,
|
||||
s3_mlp_mult = 4,
|
||||
dropout = 0.
|
||||
@@ -146,6 +156,7 @@ class CvT(nn.Module):
|
||||
|
||||
layers.append(nn.Sequential(
|
||||
nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
|
||||
LayerNorm(config['emb_dim']),
|
||||
Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)
|
||||
))
|
||||
|
||||
|
||||
@@ -150,4 +150,4 @@ class DistillWrapper(nn.Module):
|
||||
teacher_labels = teacher_logits.argmax(dim = -1)
|
||||
distill_loss = F.cross_entropy(student_logits, teacher_labels)
|
||||
|
||||
return loss * alpha + distill_loss * (1 - alpha)
|
||||
return loss * (1 - alpha) + distill_loss * alpha
|
||||
|
||||
@@ -81,7 +81,6 @@ class Attention(nn.Module):
|
||||
def apply_pos_bias(self, fmap):
|
||||
bias = self.pos_bias(self.pos_indices)
|
||||
bias = rearrange(bias, 'i j h -> () h i j')
|
||||
print(bias.shape, fmap.shape)
|
||||
return fmap + bias
|
||||
|
||||
def forward(self, x):
|
||||
@@ -136,7 +135,6 @@ class LeViT(nn.Module):
|
||||
dim_key = 32,
|
||||
dim_value = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
num_distill_classes = None
|
||||
):
|
||||
super().__init__()
|
||||
@@ -147,7 +145,7 @@ class LeViT(nn.Module):
|
||||
|
||||
assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
self.conv_embedding = nn.Sequential(
|
||||
nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
|
||||
nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
|
||||
nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
|
||||
@@ -177,7 +175,7 @@ class LeViT(nn.Module):
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
x = self.conv_embedding(img)
|
||||
|
||||
x = self.backbone(x)
|
||||
|
||||
|
||||
152
vit_pytorch/local_vit.py
Normal file
152
vit_pytorch/local_vit.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from math import sqrt
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# classes
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(x, **kwargs) + x
|
||||
|
||||
class ExcludeCLS(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
cls_token, x = x[:, :1], x[:, 1:]
|
||||
x = self.fn(x, **kwargs)
|
||||
return torch.cat((cls_token, x), dim = 1)
|
||||
|
||||
# prenorm
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
# feed forward related classes
|
||||
|
||||
class DepthWiseConv2d(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
|
||||
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, hidden_dim, 1),
|
||||
nn.Hardswish(),
|
||||
DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
|
||||
nn.Hardswish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(hidden_dim, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
h = w = int(sqrt(x.shape[-2]))
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h = h, w = w)
|
||||
x = self.net(x)
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
return x
|
||||
|
||||
# attention
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
|
||||
ExcludeCLS(Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x)
|
||||
x = ff(x)
|
||||
return x
|
||||
|
||||
# main class
|
||||
|
||||
class LocalViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
return self.mlp_head(x[:, 0])
|
||||
@@ -162,8 +162,9 @@ class PiT(nn.Module):
|
||||
layers.append(Pool(dim))
|
||||
dim *= 2
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
*layers,
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
@@ -177,4 +178,6 @@ class PiT(nn.Module):
|
||||
x += self.pos_embedding
|
||||
x = self.dropout(x)
|
||||
|
||||
return self.layers(x)
|
||||
x = self.layers(x)
|
||||
|
||||
return self.mlp_head(x[:, 0])
|
||||
|
||||
208
vit_pytorch/rvt.py
Normal file
208
vit_pytorch/rvt.py
Normal file
@@ -0,0 +1,208 @@
|
||||
from math import sqrt, pi, log
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
def rotate_every_two(x):
|
||||
x = rearrange(x, '... (d j) -> ... d j', j = 2)
|
||||
x1, x2 = x.unbind(dim = -1)
|
||||
x = torch.stack((-x2, x1), dim = -1)
|
||||
return rearrange(x, '... d j -> ... (d j)')
|
||||
|
||||
class AxialRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_freq = 10):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
|
||||
self.register_buffer('scales', scales)
|
||||
|
||||
def forward(self, x):
|
||||
device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))
|
||||
|
||||
seq = torch.linspace(-1., 1., steps = n, device = device)
|
||||
seq = seq.unsqueeze(-1)
|
||||
|
||||
scales = self.scales[(*((None,) * (len(seq.shape) - 1)), Ellipsis)]
|
||||
scales = scales.to(x)
|
||||
|
||||
seq = seq * scales * pi
|
||||
|
||||
x_sinu = repeat(seq, 'i d -> i j d', j = n)
|
||||
y_sinu = repeat(seq, 'j d -> i j d', i = n)
|
||||
|
||||
sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)
|
||||
cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)
|
||||
|
||||
sin, cos = map(lambda t: rearrange(t, 'i j d -> (i j) d'), (sin, cos))
|
||||
sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos))
|
||||
return sin, cos
|
||||
|
||||
class DepthWiseConv2d(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
|
||||
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# helper classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class SpatialConv(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, kernel, bias = False):
|
||||
super().__init__()
|
||||
self.conv = DepthWiseConv2d(dim_in, dim_out, kernel, padding = kernel // 2, bias = False)
|
||||
self.cls_proj = nn.Linear(dim_in, dim_out) if dim_in != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, fmap_dims):
|
||||
cls_token, x = x[:, :1], x[:, 1:]
|
||||
x = rearrange(x, 'b (h w) d -> b d h w', **fmap_dims)
|
||||
x = self.conv(x)
|
||||
x = rearrange(x, 'b d h w -> b (h w) d')
|
||||
cls_token = self.cls_proj(cls_token)
|
||||
return torch.cat((cls_token, x), dim = 1)
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
def forward(self, x):
|
||||
x, gates = x.chunk(2, dim = -1)
|
||||
return F.gelu(gates) * x
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim * 2),
|
||||
GEGLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = True, use_ds_conv = True, conv_query_kernel = 5):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.use_rotary = use_rotary
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.use_ds_conv = use_ds_conv
|
||||
|
||||
self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False) if use_ds_conv else nn.Linear(dim, inner_dim, bias = False)
|
||||
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, pos_emb, fmap_dims):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
|
||||
q = self.to_q(x, fmap_dims = fmap_dims) if self.use_ds_conv else self.to_q(x)
|
||||
|
||||
qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
|
||||
|
||||
if self.use_rotary:
|
||||
# apply 2d rotary embeddings to queries and keys, excluding CLS tokens
|
||||
|
||||
sin, cos = pos_emb
|
||||
dim_rotary = sin.shape[-1]
|
||||
|
||||
(q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))
|
||||
|
||||
# handle the case where rotary dimension < head dimension
|
||||
|
||||
(q, q_pass), (k, k_pass) = map(lambda t: (t[..., :dim_rotary], t[..., dim_rotary:]), (q, k))
|
||||
q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
|
||||
q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
|
||||
|
||||
# concat back the CLS tokens
|
||||
|
||||
q = torch.cat((q_cls, q), dim = 1)
|
||||
k = torch.cat((k_cls, k), dim = 1)
|
||||
|
||||
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True, use_ds_conv = True):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
self.pos_emb = AxialRotaryEmbedding(dim_head)
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x, fmap_dims):
|
||||
pos_emb = self.pos_emb(x[:, 1:])
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, pos_emb = pos_emb, fmap_dims = fmap_dims) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
# Rotary Vision Transformer
|
||||
|
||||
class RvT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True, use_ds_conv = True):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary, use_ds_conv)
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
b, _, h, w, p = *img.shape, self.patch_size
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
n = x.shape[1]
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
fmap_dims = {'h': h // p, 'w': w // p}
|
||||
x = self.transformer(x, fmap_dims = fmap_dims)
|
||||
|
||||
return self.mlp_head(x[:, 0])
|
||||
Reference in New Issue
Block a user