mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 16:12:29 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bca88e9039 | ||
|
|
96f66d2754 | ||
|
|
12249dcc5f | ||
|
|
8b8da8dede | ||
|
|
5578ac472f | ||
|
|
d446a41243 | ||
|
|
0ad09c4cbc |
6
setup.py
6
setup.py
@@ -1,11 +1,15 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
with open('README.md') as f:
|
||||
long_description = f.read()
|
||||
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.6.2',
|
||||
version = '1.6.8',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description=long_description,
|
||||
long_description_content_type = 'text/markdown',
|
||||
author = 'Phil Wang',
|
||||
author_email = 'lucidrains@gmail.com',
|
||||
|
||||
@@ -170,12 +170,13 @@ class ImageEmbedder(nn.Module):
|
||||
dim,
|
||||
image_size,
|
||||
patch_size,
|
||||
dropout = 0.
|
||||
dropout = 0.,
|
||||
channels = 3
|
||||
):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = 3 * patch_size ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
@@ -223,11 +224,12 @@ class CrossViT(nn.Module):
|
||||
cross_attn_dim_head = 64,
|
||||
depth = 3,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
emb_dropout = 0.1,
|
||||
channels = 3
|
||||
):
|
||||
super().__init__()
|
||||
self.sm_image_embedder = ImageEmbedder(dim = sm_dim, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
|
||||
self.lg_image_embedder = ImageEmbedder(dim = lg_dim, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)
|
||||
self.sm_image_embedder = ImageEmbedder(dim = sm_dim, channels= channels, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
|
||||
self.lg_image_embedder = ImageEmbedder(dim = lg_dim, channels = channels, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)
|
||||
|
||||
self.multi_scale_encoder = MultiScaleEncoder(
|
||||
depth = depth,
|
||||
|
||||
@@ -140,12 +140,13 @@ class CvT(nn.Module):
|
||||
s3_heads = 6,
|
||||
s3_depth = 10,
|
||||
s3_mlp_mult = 4,
|
||||
dropout = 0.
|
||||
dropout = 0.,
|
||||
channels = 3
|
||||
):
|
||||
super().__init__()
|
||||
kwargs = dict(locals())
|
||||
|
||||
dim = 3
|
||||
dim = channels
|
||||
layers = []
|
||||
|
||||
for prefix in ('s1', 's2', 's3'):
|
||||
|
||||
@@ -198,7 +198,7 @@ class NaViT(nn.Module):
|
||||
self.calc_token_dropout = token_dropout_prob
|
||||
|
||||
elif isinstance(token_dropout_prob, (float, int)):
|
||||
assert 0. < token_dropout_prob < 1.
|
||||
assert 0. <= token_dropout_prob < 1.
|
||||
token_dropout_prob = float(token_dropout_prob)
|
||||
self.calc_token_dropout = lambda height, width: token_dropout_prob
|
||||
|
||||
@@ -249,7 +249,7 @@ class NaViT(nn.Module):
|
||||
group_images = False,
|
||||
group_max_seq_len = 2048
|
||||
):
|
||||
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
|
||||
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout) and self.training
|
||||
|
||||
arange = partial(torch.arange, device = device)
|
||||
pad_sequence = partial(orig_pad_sequence, batch_first = True)
|
||||
@@ -260,7 +260,7 @@ class NaViT(nn.Module):
|
||||
batched_images = group_images_by_max_seq_len(
|
||||
batched_images,
|
||||
patch_size = self.patch_size,
|
||||
calc_token_dropout = self.calc_token_dropout,
|
||||
calc_token_dropout = self.calc_token_dropout if self.training else None,
|
||||
max_seq_len = group_max_seq_len
|
||||
)
|
||||
|
||||
@@ -314,8 +314,8 @@ class NaViT(nn.Module):
|
||||
# derive key padding mask
|
||||
|
||||
lengths = torch.tensor([seq.shape[-2] for seq in batched_sequences], device = device, dtype = torch.long)
|
||||
max_length = arange(lengths.amax().item())
|
||||
key_pad_mask = rearrange(lengths, 'b -> b 1') <= rearrange(max_length, 'n -> 1 n')
|
||||
seq_arange = arange(lengths.amax().item())
|
||||
key_pad_mask = rearrange(seq_arange, 'n -> 1 n') < rearrange(lengths, 'b -> b 1')
|
||||
|
||||
# derive attention mask, and combine with key padding mask from above
|
||||
|
||||
|
||||
171
vit_pytorch/simple_flash_attn_vit_3d.py
Normal file
171
vit_pytorch/simple_flash_attn_vit_3d.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from packaging import version
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# constants
|
||||
|
||||
Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
|
||||
_, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
||||
|
||||
z, y, x = torch.meshgrid(
|
||||
torch.arange(f, device = device),
|
||||
torch.arange(h, device = device),
|
||||
torch.arange(w, device = device),
|
||||
indexing = 'ij')
|
||||
|
||||
fourier_dim = dim // 6
|
||||
|
||||
omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1)
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
z = z.flatten()[:, None] * omega[None, :]
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
|
||||
|
||||
pe = F.pad(pe, (0, dim - (fourier_dim * 6))) # pad if feature dimension not cleanly divisible by 6
|
||||
return pe.type(dtype)
|
||||
|
||||
# main class
|
||||
|
||||
class Attend(Module):
|
||||
def __init__(self, use_flash = False, config: Config = Config(True, True, True)):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.use_flash = use_flash
|
||||
assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
|
||||
|
||||
def flash_attn(self, q, k, v):
|
||||
# flash attention - https://arxiv.org/abs/2205.14135
|
||||
|
||||
with torch.backends.cuda.sdp_kernel(**self.config._asdict()):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
return out
|
||||
|
||||
def forward(self, q, k, v):
|
||||
n, device, scale = q.shape[-2], q.device, q.shape[-1] ** -0.5
|
||||
|
||||
if self.use_flash:
|
||||
return self.flash_attn(q, k, v)
|
||||
|
||||
# similarity
|
||||
|
||||
sim = einsum("b h i d, b j d -> b h i j", q, k) * scale
|
||||
|
||||
# attention
|
||||
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = einsum("b h i j, b j d -> b h i d", attn, v)
|
||||
|
||||
return out
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, use_flash = True):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = Attend(use_flash = use_flash)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
out = self.attend(q, k, v)
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_flash):
|
||||
super().__init__()
|
||||
self.layers = ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, use_flash = use_flash),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return x
|
||||
|
||||
class SimpleViT(Module):
|
||||
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, use_flash_attn = True):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(image_patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_flash_attn)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, video):
|
||||
*_, h, w, dtype = *video.shape, video.dtype
|
||||
|
||||
x = self.to_patch_embedding(video)
|
||||
pe = posemb_sincos_3d(x)
|
||||
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
162
vit_pytorch/simple_vit_with_fft.py
Normal file
162
vit_pytorch/simple_vit_with_fft.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import torch
|
||||
from torch.fft import fft2
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, reduce, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, freq_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
freq_patch_height, freq_patch_width = pair(freq_patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert image_height % freq_patch_height == 0 and image_width % freq_patch_width == 0, 'Image dimensions must be divisible by the freq patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
freq_patch_dim = channels * 2 * freq_patch_height * freq_patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.to_freq_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) ri -> b (h w) (p1 p2 ri c)", p1 = freq_patch_height, p2 = freq_patch_width),
|
||||
nn.LayerNorm(freq_patch_dim),
|
||||
nn.Linear(freq_patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
)
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.freq_pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // freq_patch_height,
|
||||
w = image_width // freq_patch_width,
|
||||
dim = dim
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
device, dtype = img.device, img.dtype
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
freqs = torch.view_as_real(fft2(img))
|
||||
|
||||
f = self.to_freq_embedding(freqs)
|
||||
|
||||
x += self.pos_embedding.to(device, dtype = dtype)
|
||||
f += self.freq_pos_embedding.to(device, dtype = dtype)
|
||||
|
||||
x, ps = pack((f, x), 'b * d')
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
_, x = unpack(x, ps, 'b * d')
|
||||
x = reduce(x, 'b n d -> b d', 'mean')
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
|
||||
if __name__ == '__main__':
|
||||
vit = SimpleViT(
|
||||
num_classes = 1000,
|
||||
image_size = 256,
|
||||
patch_size = 8,
|
||||
freq_patch_size = 8,
|
||||
dim = 1024,
|
||||
depth = 1,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
)
|
||||
|
||||
images = torch.randn(8, 3, 256, 256)
|
||||
|
||||
logits = vit(images)
|
||||
Reference in New Issue
Block a user