mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fb5014f0ee | ||
|
|
0b7518ef45 | ||
|
|
077d8c188f | ||
|
|
5888f05300 | ||
|
|
d518e89573 | ||
|
|
dd6462d19b | ||
|
|
a1ee1daa1a | ||
|
|
3cff5e547a | ||
|
|
fdaf7f92b9 |
14
README.md
14
README.md
@@ -49,7 +49,7 @@
|
||||
|
||||
## Vision Transformer - Pytorch
|
||||
|
||||
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
|
||||
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the [attention](https://www.youtube.com/watch?v=eMlx5fFNoYc) revolution.
|
||||
|
||||
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>.
|
||||
|
||||
@@ -2213,4 +2213,16 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{gopalakrishnan2025decouplingwhatwherepolar,
|
||||
title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
|
||||
author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
|
||||
year = {2025},
|
||||
eprint = {2509.10534},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2509.10534},
|
||||
}
|
||||
```
|
||||
|
||||
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vit-pytorch"
|
||||
version = "1.16.0"
|
||||
version = "1.17.1"
|
||||
description = "Vision Transformer (ViT) - Pytorch"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
license = { file = "LICENSE" }
|
||||
|
||||
@@ -25,12 +25,12 @@ class DistillMixin:
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
|
||||
cls_tokens = repeat(self.cls_token, 'n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x += self.pos_embedding[:(n + 1)]
|
||||
|
||||
if distilling:
|
||||
distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
|
||||
distill_tokens = repeat(distill_token, 'n d -> b n d', b = b)
|
||||
x = torch.cat((x, distill_tokens), dim = 1)
|
||||
|
||||
x = self._attend(x)
|
||||
@@ -125,7 +125,7 @@ class DistillWrapper(Module):
|
||||
self.alpha = alpha
|
||||
self.hard = hard
|
||||
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, dim))
|
||||
|
||||
self.distill_mlp = nn.Sequential(
|
||||
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from functools import partial, lru_cache
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
@@ -9,7 +9,6 @@ from torch import nn, Tensor
|
||||
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
@@ -28,6 +27,12 @@ def pair(t):
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def posemb_grid(ph, pw, device):
|
||||
h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
|
||||
w_idx = torch.arange(pw, device=device).repeat(ph)
|
||||
return torch.stack([h_idx, w_idx], dim=-1)
|
||||
|
||||
# auto grouping images
|
||||
|
||||
def group_images_by_max_seq_len(
|
||||
@@ -117,8 +122,7 @@ class Attention(nn.Module):
|
||||
self.q_norm = RMSNorm(heads, dim_head)
|
||||
self.k_norm = RMSNorm(heads, dim_head)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.dropout_p = dropout
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
@@ -145,19 +149,22 @@ class Attention(nn.Module):
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2))
|
||||
# combine masks if both exist
|
||||
if exists(mask) or exists(attn_mask):
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
if exists(mask) and exists(attn_mask):
|
||||
attn_mask = mask & attn_mask
|
||||
elif exists(mask):
|
||||
attn_mask = mask
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
|
||||
out = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask = attn_mask,
|
||||
dropout_p = self.dropout_p if self.training else 0.,
|
||||
scale = 1. # RMSNorm already includes sqrt(dim) scaling
|
||||
)
|
||||
|
||||
if exists(attn_mask):
|
||||
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
@@ -281,42 +288,41 @@ class NaViT(nn.Module):
|
||||
for images in batched_images:
|
||||
num_images.append(len(images))
|
||||
|
||||
sequences = []
|
||||
positions = []
|
||||
image_ids = torch.empty((0,), device = device, dtype = torch.long)
|
||||
|
||||
for image_id, image in enumerate(images):
|
||||
assert image.ndim ==3 and image.shape[0] == c
|
||||
# compute patch dimensions for all images
|
||||
patch_dims = []
|
||||
for image in images:
|
||||
assert image.ndim == 3 and image.shape[0] == c
|
||||
image_dims = image.shape[-2:]
|
||||
assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
|
||||
patch_dims.append((image_dims[0] // p, image_dims[1] // p))
|
||||
|
||||
ph, pw = map(lambda dim: dim // p, image_dims)
|
||||
# extract patches for all images
|
||||
sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images]
|
||||
|
||||
pos = torch.stack(torch.meshgrid((
|
||||
arange(ph),
|
||||
arange(pw)
|
||||
), indexing = 'ij'), dim = -1)
|
||||
# compute positions - uses lru_cache to avoid redundant computation across forward passes
|
||||
positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims]
|
||||
|
||||
pos = rearrange(pos, 'h w c -> (h w) c')
|
||||
seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
|
||||
|
||||
seq_len = seq.shape[-2]
|
||||
|
||||
if has_token_dropout:
|
||||
# handle token dropout
|
||||
if has_token_dropout:
|
||||
for i, (seq, pos) in enumerate(zip(sequences, positions)):
|
||||
image_dims = images[i].shape[-2:]
|
||||
token_dropout = self.calc_token_dropout(*image_dims)
|
||||
seq_len = seq.shape[0]
|
||||
num_keep = max(1, int(seq_len * (1 - token_dropout)))
|
||||
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
|
||||
keep_indices = torch.randn((seq_len,), device=device).topk(num_keep, dim=-1).indices
|
||||
sequences[i] = seq[keep_indices]
|
||||
positions[i] = pos[keep_indices]
|
||||
|
||||
seq = seq[keep_indices]
|
||||
pos = pos[keep_indices]
|
||||
|
||||
image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
|
||||
sequences.append(seq)
|
||||
positions.append(pos)
|
||||
# build image_ids efficiently using repeat_interleave
|
||||
patch_counts = [seq.shape[0] for seq in sequences]
|
||||
image_ids = torch.repeat_interleave(
|
||||
arange(len(images)),
|
||||
torch.tensor(patch_counts, device=device)
|
||||
)
|
||||
|
||||
batched_image_ids.append(image_ids)
|
||||
batched_sequences.append(torch.cat(sequences, dim = 0))
|
||||
batched_positions.append(torch.cat(positions, dim = 0))
|
||||
batched_sequences.append(torch.cat(sequences, dim=0))
|
||||
batched_positions.append(torch.cat(positions, dim=0))
|
||||
|
||||
# derive key padding mask
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ class NaViT(Module):
|
||||
|
||||
self.channels = channels
|
||||
self.patch_size = patch_size
|
||||
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
|
||||
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c pf p1 p2)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
nn.LayerNorm(patch_dim),
|
||||
|
||||
@@ -146,7 +146,7 @@ class SimpleViT(Module):
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
@@ -103,7 +103,7 @@ class SimpleViT(nn.Module):
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
@@ -11,7 +12,7 @@ def pair(t):
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
@@ -26,7 +27,7 @@ class FeedForward(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
@@ -62,13 +63,14 @@ class Attention(nn.Module):
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
@@ -80,7 +82,7 @@ class Transformer(nn.Module):
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViT(nn.Module):
|
||||
class ViT(Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
@@ -101,8 +103,8 @@ class ViT(nn.Module):
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.randn(1, num_cls_tokens, dim))
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + num_cls_tokens, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
|
||||
self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
@@ -114,12 +116,15 @@ class ViT(nn.Module):
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
batch = img.shape[0]
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 ... d -> b ... d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
|
||||
seq = x.shape[1]
|
||||
|
||||
x = x + self.pos_embedding[:seq]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
@@ -89,7 +89,7 @@ class ViT(nn.Module):
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
353
vit_pytorch/vit_nd_pope.py
Normal file
353
vit_pytorch/vit_nd_pope.py
Normal file
@@ -0,0 +1,353 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import pi, nn, arange, cat, stack, Tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.amp import autocast
|
||||
|
||||
from einops import rearrange, repeat, reduce, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1, p = 2)
|
||||
|
||||
def join(arr, delimiter = ' '):
|
||||
return delimiter.join(arr)
|
||||
|
||||
def ensure_tuple(t, length):
|
||||
if isinstance(t, (tuple, list)):
|
||||
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
|
||||
return tuple(t)
|
||||
|
||||
return (t,) * length
|
||||
|
||||
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
||||
# https://jerryxio.ng/posts/nd-rope/
|
||||
|
||||
# but using polar version instead
|
||||
# Gopalakrishnan et al. https://arxiv.org/abs/2509.10534
|
||||
|
||||
def _phi(m: int) -> float:
|
||||
x = 2.0
|
||||
for _ in range(10):
|
||||
x = (1 + x) ** (1.0 / (m + 1.0))
|
||||
return x
|
||||
|
||||
def make_directions(n: int, d: int) -> Tensor:
|
||||
g = _phi(d)
|
||||
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
|
||||
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
|
||||
z = torch.fmod(i * alpha, 1.0)
|
||||
directions = torch.erfinv(2.0 * z - 1.0)
|
||||
directions = l2norm(directions)
|
||||
return directions.float()
|
||||
|
||||
class GoldenGatePoPENd(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_pos: int,
|
||||
heads: int,
|
||||
dim_head: int,
|
||||
min_freq: float = 1.0,
|
||||
max_freq: float = 10000.0,
|
||||
p_zero_freqs: float = 0.0, # proportion of frequencies set to 0
|
||||
init_learned_bias_uniform = False
|
||||
):
|
||||
super().__init__()
|
||||
n_freqs = dim_head
|
||||
n_zero_freqs = round(p_zero_freqs * n_freqs)
|
||||
|
||||
omega = cat((
|
||||
torch.zeros(n_zero_freqs),
|
||||
min_freq * (max_freq / min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
|
||||
))
|
||||
|
||||
directions = rearrange(
|
||||
make_directions(heads * n_freqs, dim_pos),
|
||||
'(h f) p -> h f p',
|
||||
h = heads
|
||||
)
|
||||
|
||||
omega_expanded = rearrange(omega, 'f -> f 1')
|
||||
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
|
||||
|
||||
self.learned_bias = nn.Parameter(torch.zeros(heads, dim_head))
|
||||
|
||||
if init_learned_bias_uniform:
|
||||
self.learned_bias.uniform_(-2. * pi, 0.)
|
||||
|
||||
@autocast('cuda', enabled = False)
|
||||
def forward(self, pos):
|
||||
|
||||
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
|
||||
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
|
||||
|
||||
# compute theta for each (batch, head, seq, freq)
|
||||
|
||||
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
|
||||
|
||||
bias = self.learned_bias.clamp(-2. * pi, 0.)
|
||||
bias = rearrange(bias, 'h d -> h 1 d')
|
||||
|
||||
return theta, bias
|
||||
|
||||
@autocast('cuda', enabled = False)
|
||||
def apply_polar_pos_emb(t, freqs):
|
||||
orig_dtype = t.dtype
|
||||
|
||||
t = t.float()
|
||||
t = F.softplus(t)
|
||||
|
||||
out = cat((t * freqs.cos(), t * freqs.sin()), dim = -1)
|
||||
|
||||
return out.type(orig_dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x, polar_pos_emb = None):
|
||||
x = self.norm(x)
|
||||
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
if exists(polar_pos_emb):
|
||||
freqs, bias = polar_pos_emb
|
||||
q = apply_polar_pos_emb(q, freqs)
|
||||
k = apply_polar_pos_emb(k, freqs + bias)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., polar_emb = None):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.polar_emb = polar_emb
|
||||
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x, pos = None):
|
||||
|
||||
# pope embedding
|
||||
|
||||
polar_pos_emb = None
|
||||
if exists(pos) and exists(self.polar_emb):
|
||||
polar_pos_emb = self.polar_emb(pos)
|
||||
|
||||
# transformer layers
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, polar_pos_emb) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViTND(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ndim: int,
|
||||
input_shape: int | tuple[int, ...],
|
||||
patch_size: int | tuple[int, ...],
|
||||
num_classes: int,
|
||||
dim: int,
|
||||
depth: int,
|
||||
heads: int,
|
||||
mlp_dim: int,
|
||||
channels: int = 3,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.,
|
||||
emb_dropout: float = 0.,
|
||||
pope_min_freq: float = 1.0,
|
||||
pope_max_freq: float = 10000.0,
|
||||
pope_p_zero_freqs: float = 0.0,
|
||||
pope_init_learned_bias_uniform = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
|
||||
|
||||
self.ndim = ndim
|
||||
|
||||
input_shape = ensure_tuple(input_shape, ndim)
|
||||
patch_size = ensure_tuple(patch_size, ndim)
|
||||
|
||||
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
|
||||
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
|
||||
|
||||
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
|
||||
num_patches = 1
|
||||
for n in num_patches_per_dim:
|
||||
num_patches *= n
|
||||
|
||||
patch_dim = channels
|
||||
for p in patch_size:
|
||||
patch_dim *= p
|
||||
|
||||
dim_names = 'fghijkl'[:ndim]
|
||||
|
||||
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
|
||||
patch_dims = [f'p{i}' for i in range(ndim)]
|
||||
|
||||
input_pattern = f'b c {join(input_dims)}'
|
||||
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
|
||||
rearrange_str = f'{input_pattern} -> {output_pattern}'
|
||||
|
||||
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange(rearrange_str, **rearrange_kwargs),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
# golden gate pope
|
||||
|
||||
self.polar_emb = GoldenGatePoPENd(
|
||||
dim_pos = ndim,
|
||||
heads = heads,
|
||||
dim_head = dim_head,
|
||||
min_freq = pope_min_freq,
|
||||
max_freq = pope_max_freq,
|
||||
p_zero_freqs = pope_p_zero_freqs,
|
||||
init_learned_bias_uniform = pope_init_learned_bias_uniform
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, polar_emb = self.polar_emb)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def muon_parameters(self):
|
||||
params = []
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, Attention):
|
||||
params.extend([
|
||||
m.to_v.weight,
|
||||
m.to_out[0].weight
|
||||
])
|
||||
elif isinstance(m, FeedForward):
|
||||
params.extend([
|
||||
m.net[1].weight,
|
||||
m.net[-2].weight
|
||||
])
|
||||
|
||||
return params
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_embed = False
|
||||
):
|
||||
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
|
||||
|
||||
batch, *spatial_dims, _, device = *x.shape, x.device
|
||||
|
||||
# Generate position coordinates
|
||||
|
||||
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
|
||||
grid = torch.meshgrid(*grids, indexing = 'ij')
|
||||
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
|
||||
|
||||
# flatten spatial dimensions for attention with nd rotary
|
||||
|
||||
pos = repeat(pos, '... p -> b (...) p', b = batch)
|
||||
x, packed_shape = pack([x], 'b * d')
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
embed = self.transformer(x, pos)
|
||||
|
||||
# return the embed with reconstituted patch shape
|
||||
|
||||
if return_embed:
|
||||
embed, = unpack(embed, packed_shape, 'b * d')
|
||||
return embed
|
||||
|
||||
# pooling to logits
|
||||
|
||||
pooled = reduce(embed, 'b n d -> b d', 'mean')
|
||||
|
||||
pooled = self.to_latent(pooled)
|
||||
return self.mlp_head(pooled)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model = ViTND(
|
||||
ndim = 5,
|
||||
input_shape = (4, 8, 16, 32, 64),
|
||||
patch_size = (2, 2, 4, 4, 8),
|
||||
num_classes = 1000,
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
channels = 3,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
data = torch.randn(3, 3, 4, 8, 16, 32, 64)
|
||||
|
||||
logits = model(data)
|
||||
|
||||
embed = model(data, return_embed = True)
|
||||
assert embed.shape == (3, 2, 4, 4, 8, 8, 512)
|
||||
@@ -141,7 +141,7 @@ class ViT(nn.Module):
|
||||
self.global_average_pool = pool == 'mean'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
|
||||
Reference in New Issue
Block a user