Mirror of https://github.com/lucidrains/vit-pytorch.git, synced 2025-12-30 16:12:29 +00:00
Compare commits
14 Commits
| SHA1 |
|---|
| bca88e9039 |
| 96f66d2754 |
| 12249dcc5f |
| 8b8da8dede |
| 5578ac472f |
| d446a41243 |
| 0ad09c4cbc |
| 92b69321f4 |
| fb4ac25174 |
| 53fe345e85 |
| efb94608ea |
| 51310d1d07 |
| 1616288e30 |
| 9e1e824385 |
README.md (45 changed lines)
@@ -25,6 +25,7 @@
- [MaxViT](#maxvit)
- [NesT](#nest)
- [MobileViT](#mobilevit)
- [XCiT](#xcit)
- [Masked Autoencoder](#masked-autoencoder)
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
@@ -92,7 +93,7 @@ preds = v(img) # (1, 1000)
- `image_size`: int.
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
- `patch_size`: int.
Number of patches. `image_size` must be divisible by `patch_size`.
Size of patches. `image_size` must be divisible by `patch_size`.
The number of patches is: `n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
- `num_classes`: int.
Number of classes to classify.
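As a quick worked example of the patch-count rule above (values assumed for illustration):

```python
image_size = 256
patch_size = 32

n = (image_size // patch_size) ** 2  # number of patches the transformer sees
assert n > 16                        # constraint stated in the README
print(n)                             # 64
```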
@@ -772,6 +773,38 @@ img = torch.randn(1, 3, 256, 256)
pred = mbvit_xs(img) # (1, 1000)
```

## XCiT

<img src="./images/xcit.png" width="400px"></img>

This <a href="https://arxiv.org/abs/2106.09681">paper</a> introduces cross covariance attention (abbreviated XCA). One can think of it as doing attention across the feature dimension rather than the spatial one (another perspective would be a dynamic 1x1 convolution, the kernel being an attention map defined by spatial correlations).

Technically, this amounts to simply transposing the queries, keys and values before executing cosine similarity attention with a learned temperature.

```python
import torch
from vit_pytorch.xcit import XCiT

v = XCiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,                     # depth of xcit transformer
    cls_depth = 2,                  # depth of cross attention of CLS tokens to patch, attention pool at end
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    layer_dropout = 0.05,           # randomly dropout 5% of the layers
    local_patch_kernel_size = 3     # kernel size of the local patch interaction module (depthwise convs)
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
```
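To make the transpose-then-attend description above concrete, here is a minimal, self-contained sketch of the cross covariance attention step. It is a simplification for illustration only (no learned projections; the function name and the fixed temperature are assumptions): the actual `XCAttention` module in `vit_pytorch/xcit.py` adds learned qkv projections, a per-head learned temperature, dropout and an output projection.

```python
import torch
import torch.nn.functional as F
from torch import einsum
from einops import rearrange

def cross_covariance_attention(x, heads = 8, temperature = 1.):
    # x: (batch, num_patches, dim); for simplicity q = k = v = x, split into heads
    q = k = v = rearrange(x, 'b n (h d) -> b h d n', h = heads)

    # cosine similarity: l2-normalize each feature channel across the token axis
    q, k = map(lambda t: F.normalize(t, dim = -1), (q, k))

    # the attention map lives in feature space: (d x d) per head, not (n x n)
    attn = (einsum('b h i n, b h j n -> b h i j', q, k) * temperature).softmax(dim = -1)

    # aggregate: each output feature channel is a mix of input feature channels
    out = einsum('b h i j, b h j n -> b h i n', attn, v)
    return rearrange(out, 'b h d n -> b n (h d)')

tokens = torch.randn(1, 64, 512)                   # 64 patch tokens of dimension 512
print(cross_covariance_attention(tokens).shape)    # torch.Size([1, 64, 512])
```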
## Simple Masked Image Modeling

<img src="./images/simmim.png" width="400px"/>
@@ -2029,4 +2062,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@inproceedings{ElNouby2021XCiTCI,
    title = {XCiT: Cross-Covariance Image Transformers},
    author = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
    booktitle = {Neural Information Processing Systems},
    year = {2021},
    url = {https://api.semanticscholar.org/CorpusID:235458262}
}
```

*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
images/xcit.png (new binary file, 814 KiB; binary content not shown)
setup.py (6 changed lines)
@@ -1,11 +1,15 @@
from setuptools import setup, find_packages

with open('README.md') as f:
  long_description = f.read()

setup(
  name = 'vit-pytorch',
  packages = find_packages(exclude=['examples']),
  version = '1.5.3',
  version = '1.6.8',
  license='MIT',
  description = 'Vision Transformer (ViT) - Pytorch',
  long_description=long_description,
  long_description_content_type = 'text/markdown',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
@@ -1,10 +1,3 @@
import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse('2.0.0'):
    from einops._torch_specific import allow_ops_in_compiled_graph
    allow_ops_in_compiled_graph()

from vit_pytorch.vit import ViT
from vit_pytorch.simple_vit import SimpleViT
@@ -170,12 +170,13 @@ class ImageEmbedder(nn.Module):
        dim,
        image_size,
        patch_size,
        dropout = 0.
        dropout = 0.,
        channels = 3
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2
        patch_dim = channels * patch_size ** 2

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
@@ -223,11 +224,12 @@ class CrossViT(nn.Module):
        cross_attn_dim_head = 64,
        depth = 3,
        dropout = 0.1,
        emb_dropout = 0.1
        emb_dropout = 0.1,
        channels = 3
    ):
        super().__init__()
        self.sm_image_embedder = ImageEmbedder(dim = sm_dim, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
        self.lg_image_embedder = ImageEmbedder(dim = lg_dim, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)
        self.sm_image_embedder = ImageEmbedder(dim = sm_dim, channels = channels, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
        self.lg_image_embedder = ImageEmbedder(dim = lg_dim, channels = channels, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)

        self.multi_scale_encoder = MultiScaleEncoder(
            depth = depth,
@@ -140,12 +140,13 @@ class CvT(nn.Module):
        s3_heads = 6,
        s3_depth = 10,
        s3_mlp_mult = 4,
        dropout = 0.
        dropout = 0.,
        channels = 3
    ):
        super().__init__()
        kwargs = dict(locals())

        dim = 3
        dim = channels
        layers = []

        for prefix in ('s1', 's2', 's3'):
@@ -198,7 +198,7 @@ class NaViT(nn.Module):
            self.calc_token_dropout = token_dropout_prob

        elif isinstance(token_dropout_prob, (float, int)):
            assert 0. < token_dropout_prob < 1.
            assert 0. <= token_dropout_prob < 1.
            token_dropout_prob = float(token_dropout_prob)
            self.calc_token_dropout = lambda height, width: token_dropout_prob

@@ -249,7 +249,7 @@
        group_images = False,
        group_max_seq_len = 2048
    ):
        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout) and self.training

        arange = partial(torch.arange, device = device)
        pad_sequence = partial(orig_pad_sequence, batch_first = True)
@@ -260,7 +260,7 @@
            batched_images = group_images_by_max_seq_len(
                batched_images,
                patch_size = self.patch_size,
                calc_token_dropout = self.calc_token_dropout,
                calc_token_dropout = self.calc_token_dropout if self.training else None,
                max_seq_len = group_max_seq_len
            )

@@ -314,8 +314,8 @@
        # derive key padding mask

        lengths = torch.tensor([seq.shape[-2] for seq in batched_sequences], device = device, dtype = torch.long)
        max_length = arange(lengths.amax().item())
        key_pad_mask = rearrange(lengths, 'b -> b 1') <= rearrange(max_length, 'n -> 1 n')
        seq_arange = arange(lengths.amax().item())
        key_pad_mask = rearrange(seq_arange, 'n -> 1 n') < rearrange(lengths, 'b -> b 1')

        # derive attention mask, and combine with key padding mask from above
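The NaViT hunks above gate token dropout on `self.training`, so tokens are only dropped while training and every token is kept at evaluation time. A standalone sketch of that pattern (a hypothetical `TokenDropout` module for illustration, not the library's `NaViT` code):

```python
import torch
from torch import nn

class TokenDropout(nn.Module):
    # randomly keeps a subset of tokens during training, passes all tokens through at eval time
    def __init__(self, prob: float):
        super().__init__()
        assert 0. <= prob < 1.
        self.prob = prob

    def forward(self, tokens):  # tokens: (batch, seq, dim)
        if not self.training or self.prob == 0.:
            return tokens

        batch, seq, _ = tokens.shape
        num_keep = max(1, int(seq * (1. - self.prob)))

        # pick a random subset of token indices per batch element
        keep_indices = torch.randn(batch, seq, device = tokens.device).topk(num_keep, dim = -1).indices
        batch_indices = torch.arange(batch, device = tokens.device).unsqueeze(-1)

        return tokens[batch_indices, keep_indices]

dropout = TokenDropout(0.25)
x = torch.randn(2, 64, 512)
print(dropout(x).shape)   # torch.Size([2, 48, 512]) while training
dropout.eval()
print(dropout(x).shape)   # torch.Size([2, 64, 512]) at eval time
```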
vit_pytorch/simple_flash_attn_vit_3d.py (new file, 171 lines)
@@ -0,0 +1,171 @@
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module, ModuleList

from einops import rearrange
from einops.layers.torch import Rearrange

# constants

Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
    _, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    z, y, x = torch.meshgrid(
        torch.arange(f, device = device),
        torch.arange(h, device = device),
        torch.arange(w, device = device),
    indexing = 'ij')

    fourier_dim = dim // 6

    omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1)
    omega = 1. / (temperature ** omega)

    z = z.flatten()[:, None] * omega[None, :]
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)

    pe = F.pad(pe, (0, dim - (fourier_dim * 6))) # pad if feature dimension not cleanly divisible by 6
    return pe.type(dtype)

# main class

class Attend(Module):
    def __init__(self, use_flash = False, config: Config = Config(True, True, True)):
        super().__init__()
        self.config = config
        self.use_flash = use_flash
        assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

    def flash_attn(self, q, k, v):
        # flash attention - https://arxiv.org/abs/2205.14135

        with torch.backends.cuda.sdp_kernel(**self.config._asdict()):
            out = F.scaled_dot_product_attention(q, k, v)

        return out

    def forward(self, q, k, v):
        n, device, scale = q.shape[-2], q.device, q.shape[-1] ** -0.5

        if self.use_flash:
            return self.flash_attn(q, k, v)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # attention

        attn = sim.softmax(dim=-1)

        # aggregate values

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out

# classes

class FeedForward(Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, use_flash = True):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = Attend(use_flash = use_flash)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        out = self.attend(q, k, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_flash):
        super().__init__()
        self.layers = ModuleList([])
        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, use_flash = use_flash),
                FeedForward(dim, mlp_dim)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x

class SimpleViT(Module):
    def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, use_flash_attn = True):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(image_patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'

        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
        patch_dim = channels * patch_height * patch_width * frame_patch_size

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_flash_attn)

        self.to_latent = nn.Identity()
        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, video):
        *_, h, w, dtype = *video.shape, video.dtype

        x = self.to_patch_embedding(video)
        pe = posemb_sincos_3d(x)
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)
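A quick usage sketch for the new `vit_pytorch/simple_flash_attn_vit_3d.py` module above. The hyperparameter values here are illustrative assumptions, not values taken from the commit; videos are expected as `(batch, channels, frames, height, width)`.

```python
import torch
from vit_pytorch.simple_flash_attn_vit_3d import SimpleViT

v = SimpleViT(
    image_size = 128,          # frame height / width
    image_patch_size = 16,     # spatial patch size
    frames = 16,               # frames per clip
    frame_patch_size = 2,      # temporal patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    use_flash_attn = True      # requires pytorch >= 2.0; set False for the plain attention fallback
)

video = torch.randn(2, 3, 16, 128, 128)  # (batch, channels, frames, height, width)

preds = v(video)  # (2, 1000)
```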
vit_pytorch/simple_vit_with_fft.py (new file, 162 lines)
@@ -0,0 +1,162 @@
import torch
from torch.fft import fft2
from torch import nn

from einops import rearrange, reduce, pack, unpack
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

class SimpleViT(nn.Module):
    def __init__(self, *, image_size, patch_size, freq_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        freq_patch_height, freq_patch_width = pair(freq_patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert image_height % freq_patch_height == 0 and image_width % freq_patch_width == 0, 'Image dimensions must be divisible by the freq patch size.'

        patch_dim = channels * patch_height * patch_width
        freq_patch_dim = channels * 2 * freq_patch_height * freq_patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.to_freq_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) ri -> b (h w) (p1 p2 ri c)", p1 = freq_patch_height, p2 = freq_patch_width),
            nn.LayerNorm(freq_patch_dim),
            nn.Linear(freq_patch_dim, dim),
            nn.LayerNorm(dim)
        )

        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        self.freq_pos_embedding = posemb_sincos_2d(
            h = image_height // freq_patch_height,
            w = image_width // freq_patch_width,
            dim = dim
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        device, dtype = img.device, img.dtype

        x = self.to_patch_embedding(img)
        freqs = torch.view_as_real(fft2(img))

        f = self.to_freq_embedding(freqs)

        x += self.pos_embedding.to(device, dtype = dtype)
        f += self.freq_pos_embedding.to(device, dtype = dtype)

        x, ps = pack((f, x), 'b * d')

        x = self.transformer(x)

        _, x = unpack(x, ps, 'b * d')
        x = reduce(x, 'b n d -> b d', 'mean')

        x = self.to_latent(x)
        return self.linear_head(x)

if __name__ == '__main__':
    vit = SimpleViT(
        num_classes = 1000,
        image_size = 256,
        patch_size = 8,
        freq_patch_size = 8,
        dim = 1024,
        depth = 1,
        heads = 8,
        mlp_dim = 2048,
    )

    images = torch.randn(8, 3, 256, 256)

    logits = vit(images)
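The new `simple_vit_with_fft.py` module above tokenizes both the raw image and its 2D Fourier spectrum. A quick shape check of the frequency branch input (values assumed for illustration):

```python
import torch
from torch.fft import fft2

img = torch.randn(1, 3, 256, 256)

# complex spectrum split into real / imaginary parts, giving the trailing `ri` axis
freqs = torch.view_as_real(fft2(img))
print(freqs.shape)   # torch.Size([1, 3, 256, 256, 2])
```

This trailing axis of size 2 is why `freq_patch_dim` above is `channels * 2 * freq_patch_height * freq_patch_width`.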
@@ -10,7 +10,7 @@ class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Layernorm(dim),
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
vit_pytorch/xcit.py (new file, 283 lines)
@@ -0,0 +1,283 @@
from random import randrange

import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)

def dropout_layers(layers, dropout):
    if dropout == 0:
        return layers

    num_layers = len(layers)
    to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout

    # make sure at least one layer makes it
    if all(to_drop):
        rand_index = randrange(num_layers)
        to_drop[rand_index] = False

    layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
    return layers

# classes

class LayerScale(Module):
    def __init__(self, dim, fn, depth):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif 18 < depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        self.fn = fn
        self.scale = nn.Parameter(torch.full((dim,), init_eps))

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

class FeedForward(Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None):
        h = self.heads

        x = self.norm(x)
        context = x if not exists(context) else torch.cat((x, context), dim = 1)

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(sim)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class XCAttention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        h = self.heads
        x, ps = pack_one(x, 'b * d')

        x = self.norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h d n', h = h), (q, k, v))

        q, k = map(l2norm, (q, k))

        sim = einsum('b h i n, b h j n -> b h i j', q, k) * self.temperature.exp()

        attn = self.attend(sim)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j n -> b h i n', attn, v)
        out = rearrange(out, 'b h d n -> b n (h d)')

        out = unpack_one(out, ps, 'b * d')
        return self.to_out(out)

class LocalPatchInteraction(Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        assert (kernel_size % 2) == 1
        padding = kernel_size // 2

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b h w c -> b c h w'),
            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
            nn.BatchNorm2d(dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
            Rearrange('b c h w -> b h w c'),
        )

    def forward(self, x):
        return self.net(x)

class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = ModuleList([])
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            layer = ind + 1
            self.layers.append(ModuleList([
                LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
            ]))

    def forward(self, x, context = None):
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for attn, ff in layers:
            x = attn(x, context = context) + x
            x = ff(x) + x

        return x

class XCATransformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size = 3, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = ModuleList([])
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            layer = ind + 1
            self.layers.append(ModuleList([
                LayerScale(dim, XCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
                LayerScale(dim, LocalPatchInteraction(dim, local_patch_kernel_size), depth = layer),
                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
            ]))

    def forward(self, x):
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for cross_covariance_attn, local_patch_interaction, ff in layers:
            x = cross_covariance_attn(x) + x
            x = local_patch_interaction(x) + x
            x = ff(x) + x

        return x

class XCiT(Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        cls_depth,
        heads,
        mlp_dim,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        local_patch_kernel_size = 3,
        layer_dropout = 0.
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.cls_token = nn.Parameter(torch.randn(dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.xcit_transformer = XCATransformer(dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size, dropout, layer_dropout)

        self.final_norm = nn.LayerNorm(dim)

        self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)

        x, ps = pack_one(x, 'b * d')

        b, n, _ = x.shape
        x += self.pos_embedding[:, :n]

        x = unpack_one(x, ps, 'b * d')

        x = self.dropout(x)

        x = self.xcit_transformer(x)

        x = self.final_norm(x)

        cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)

        x = rearrange(x, 'b ... d -> b (...) d')
        cls_tokens = self.cls_transformer(cls_tokens, context = x)

        return self.mlp_head(cls_tokens[:, 0])