mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d446a41243 | ||
|
|
0ad09c4cbc | ||
|
|
92b69321f4 | ||
|
|
fb4ac25174 | ||
|
|
53fe345e85 | ||
|
|
efb94608ea | ||
|
|
51310d1d07 |
@@ -777,7 +777,7 @@ pred = mbvit_xs(img) # (1, 1000)
|
||||
|
||||
<img src="./images/xcit.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2106.09681">paper</a> introduces the cross correlation attention (abbreviated XCA). One can think of it as doing attention across the features dimension rather than the spatial one (another perspective would be a dynamic 1x1 convolution, the kernel being attention map defined by spatial correlations).
|
||||
This <a href="https://arxiv.org/abs/2106.09681">paper</a> introduces the cross covariance attention (abbreviated XCA). One can think of it as doing attention across the features dimension rather than the spatial one (another perspective would be a dynamic 1x1 convolution, the kernel being attention map defined by spatial correlations).
|
||||
|
||||
Technically, this amounts to simply transposing the query, key, values before executing cosine similarity attention with learned temperature.
|
||||
|
||||
|
||||
BIN
images/xcit.png
Normal file
BIN
images/xcit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 814 KiB |
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.6.0',
|
||||
version = '1.6.4',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description_content_type = 'text/markdown',
|
||||
|
||||
@@ -1,10 +1,3 @@
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('2.0.0'):
|
||||
from einops._torch_specific import allow_ops_in_compiled_graph
|
||||
allow_ops_in_compiled_graph()
|
||||
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.simple_vit import SimpleViT
|
||||
|
||||
|
||||
@@ -140,12 +140,13 @@ class CvT(nn.Module):
|
||||
s3_heads = 6,
|
||||
s3_depth = 10,
|
||||
s3_mlp_mult = 4,
|
||||
dropout = 0.
|
||||
dropout = 0.,
|
||||
channels = 3
|
||||
):
|
||||
super().__init__()
|
||||
kwargs = dict(locals())
|
||||
|
||||
dim = 3
|
||||
dim = channels
|
||||
layers = []
|
||||
|
||||
for prefix in ('s1', 's2', 's3'):
|
||||
|
||||
162
vit_pytorch/simple_vit_with_fft.py
Normal file
162
vit_pytorch/simple_vit_with_fft.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import torch
|
||||
from torch.fft import fft
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, reduce, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, freq_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
freq_patch_height, freq_patch_width = pair(freq_patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert image_height % freq_patch_height == 0 and image_width % freq_patch_width == 0, 'Image dimensions must be divisible by the freq patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
freq_patch_dim = channels * 2 * freq_patch_height * freq_patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.to_freq_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) ri -> b (h w) (p1 p2 ri c)", p1 = freq_patch_height, p2 = freq_patch_width),
|
||||
nn.LayerNorm(freq_patch_dim),
|
||||
nn.Linear(freq_patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.freq_pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // freq_patch_height,
|
||||
w = image_width // freq_patch_width,
|
||||
dim = dim
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
device, dtype = img.device, img.dtype
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
freqs = torch.view_as_real(fft(img))
|
||||
|
||||
f = self.to_freq_embedding(freqs)
|
||||
|
||||
x += self.pos_embedding.to(device, dtype = dtype)
|
||||
f += self.freq_pos_embedding.to(device, dtype = dtype)
|
||||
|
||||
x, ps = pack((f, x), 'b * d')
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
_, x = unpack(x, ps, 'b * d')
|
||||
x = reduce(x, 'b n d -> b d', 'mean')
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
|
||||
if __name__ == '__main__':
|
||||
vit = SimpleViT(
|
||||
num_classes = 1000,
|
||||
image_size = 256,
|
||||
patch_size = 8,
|
||||
freq_patch_size = 8,
|
||||
dim = 1024,
|
||||
depth = 1,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
)
|
||||
|
||||
images = torch.randn(8, 3, 256, 256)
|
||||
|
||||
logits = vit(images)
|
||||
@@ -10,7 +10,7 @@ class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Layernorm(dim),
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
|
||||
Reference in New Issue
Block a user