from __future__ import annotations

from functools import partial, lru_cache
from typing import List

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence

from einops import rearrange, repeat

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def always(val):
    return lambda *args: val

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def divisible_by(numer, denom):
    return (numer % denom) == 0

@lru_cache(maxsize = 128)
def posemb_grid(ph, pw, device):
    # cached (height index, width index) grid for a ph x pw patch layout, flattened in row major order
    h_idx = torch.arange(ph, device = device).repeat_interleave(pw)
    w_idx = torch.arange(pw, device = device).repeat(ph)
    return torch.stack([h_idx, w_idx], dim = -1)
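
# a small illustration (hypothetical 2 x 3 patch grid) of what the cached grid looks like:
#   posemb_grid(2, 3, torch.device('cpu'))
#   -> tensor([[0, 0], [0, 1], [0, 2],
#              [1, 0], [1, 1], [1, 2]])
# each row holds the (height, width) indices of one patch, matching the row major patch extraction below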

# auto grouping images

def group_images_by_max_seq_len(
    images: List[Tensor],
    patch_size: int,
    calc_token_dropout = None,
    max_seq_len = 2048
) -> List[List[Tensor]]:

    calc_token_dropout = default(calc_token_dropout, always(0.))

    groups = []
    group = []
    seq_len = 0

    if isinstance(calc_token_dropout, (float, int)):
        calc_token_dropout = always(calc_token_dropout)

    for image in images:
        assert isinstance(image, Tensor)

        image_dims = image.shape[-2:]
        ph, pw = map(lambda t: t // patch_size, image_dims)

        image_seq_len = (ph * pw)
        image_seq_len = int(image_seq_len * (1 - calc_token_dropout(*image_dims)))

        assert image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'

        # start a new group whenever adding this image would overflow the max sequence length

        if (seq_len + image_seq_len) > max_seq_len:
            groups.append(group)

            group = []
            seq_len = 0

        group.append(image)
        seq_len += image_seq_len

    if len(group) > 0:
        groups.append(group)

    return groups
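
# usage sketch (image sizes below are illustrative): with patch_size = 32 and max_seq_len = 80,
#   images = [torch.randn(3, 256, 256), torch.randn(3, 128, 64), torch.randn(3, 64, 256)]
#   group_images_by_max_seq_len(images, patch_size = 32, max_seq_len = 80)
#   -> [[images[0], images[1]], [images[2]]]
# the first two images cost 64 + 8 patches; the third (16 patches) would overflow 80, so it starts a new group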

# normalization

# they use layernorm without bias, which older versions of pytorch's nn.LayerNorm do not offer natively

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# they use a query-key normalization equivalent to rms norm (no mean-centering, learned gamma), from the ViT-22B paper

class RMSNorm(nn.Module):
    def __init__(self, heads, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * self.gamma

# feedforward

def FeedForward(dim, hidden_dim, dropout = 0.):
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
        nn.Dropout(dropout)
    )

# attention

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.norm = LayerNorm(dim)

        self.q_norm = RMSNorm(heads, dim_head)
        self.k_norm = RMSNorm(heads, dim_head)

        self.dropout_p = dropout

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_mask = None
    ):
        x = self.norm(x)
        kv_input = default(context, x)

        qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # combine the key padding mask with the attention mask, if given

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            attn_mask = (mask & attn_mask) if exists(attn_mask) else mask

        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask = attn_mask,
            dropout_p = self.dropout_p if self.training else 0.,
            scale = 1. # qk rmsnorm already includes the sqrt(dim_head) scaling, so disable sdpa's default
        )

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
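
# a brief shape note (illustrative): for x of shape (batch, n, dim),
#   attn = Attention(dim = 512, heads = 8, dim_head = 64)
#   attn(x)                           # (batch, n, 512), self attention
#   attn(queries, context = x)        # cross attention, as the attention pool below uses it
# mask is a boolean (batch, j) key padding mask; attn_mask is boolean and broadcastable to (batch, heads, i, j)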

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

        self.norm = LayerNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        attn_mask = None
    ):
        for attn, ff in self.layers:
            x = attn(x, mask = mask, attn_mask = attn_mask) + x
            x = ff(x) + x

        return self.norm(x)

class NaViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        channels = 3,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        token_dropout_prob = None
    ):
        super().__init__()
        image_height, image_width = pair(image_size)

        # what percent of tokens to dropout
        # if an int or float is given, assume a constant dropout prob
        # otherwise accept a callback that calculates the dropout prob from the image height and width

        self.calc_token_dropout = None

        if callable(token_dropout_prob):
            self.calc_token_dropout = token_dropout_prob

        elif isinstance(token_dropout_prob, (float, int)):
            assert 0. <= token_dropout_prob < 1.

            token_dropout_prob = float(token_dropout_prob)
            self.calc_token_dropout = lambda height, width: token_dropout_prob

        # calculate patching related stuff

        assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'

        patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
        patch_dim = channels * (patch_size ** 2)

        self.channels = channels
        self.patch_size = patch_size

        self.to_patch_embedding = nn.Sequential(
            LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            LayerNorm(dim),
        )

        # factorized 2d absolute positional embedding, indexed by patch row and column

        self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
        self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # final attention pooling queries

        self.attn_pool_queries = nn.Parameter(torch.randn(dim))
        self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)

        # output to logits

        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, num_classes, bias = False)
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        batched_images: List[Tensor] | List[List[Tensor]], # assume different resolution images are already grouped correctly
        group_images = False,
        group_max_seq_len = 2048
    ):
        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout) and self.training

        arange = partial(torch.arange, device = device)
        pad_sequence = partial(orig_pad_sequence, batch_first = True)

        # auto pack if specified

        if group_images:
            batched_images = group_images_by_max_seq_len(
                batched_images,
                patch_size = self.patch_size,
                calc_token_dropout = self.calc_token_dropout if self.training else None,
                max_seq_len = group_max_seq_len
            )

        # if an ungrouped List[Tensor] was given, promote it to List[List[Tensor]]

        if torch.is_tensor(batched_images[0]):
            batched_images = [batched_images]

        # process images into variable length sequences with an attention mask

        num_images = []
        batched_sequences = []
        batched_positions = []
        batched_image_ids = []

        for images in batched_images:
            num_images.append(len(images))

            # compute patch dimensions for all images

            patch_dims = []

            for image in images:
                assert image.ndim == 3 and image.shape[0] == c, f'each image must be a ({c}, height, width) tensor'

                image_dims = image.shape[-2:]
                assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'

                patch_dims.append((image_dims[0] // p, image_dims[1] // p))

            # extract patches for all images

            sequences = [rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p) for image in images]

            # compute positions - posemb_grid is lru_cached to avoid redundant computation across forward passes

            positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims]

            # token dropout - keep a random subset of patches (and their positions) per image

            if has_token_dropout:
                for i, (seq, pos) in enumerate(zip(sequences, positions)):
                    image_dims = images[i].shape[-2:]
                    token_dropout = self.calc_token_dropout(*image_dims)

                    seq_len = seq.shape[0]
                    num_keep = max(1, int(seq_len * (1 - token_dropout)))

                    keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices

                    sequences[i] = seq[keep_indices]
                    positions[i] = pos[keep_indices]

            # build image ids efficiently using repeat_interleave

            patch_counts = [seq.shape[0] for seq in sequences]

            image_ids = torch.repeat_interleave(
                arange(len(images)),
                torch.tensor(patch_counts, device = device)
            )

            batched_image_ids.append(image_ids)
            batched_sequences.append(torch.cat(sequences, dim = 0))
            batched_positions.append(torch.cat(positions, dim = 0))

        # derive key padding mask

        lengths = torch.tensor([seq.shape[-2] for seq in batched_sequences], device = device, dtype = torch.long)
        seq_arange = arange(lengths.amax().item())
        key_pad_mask = rearrange(seq_arange, 'n -> 1 n') < rearrange(lengths, 'b -> b 1')

        # derive attention mask, and combine with key padding mask from above
        # patches may only attend to other patches of the same image within their packed sequence

        batched_image_ids = pad_sequence(batched_image_ids)
        attn_mask = rearrange(batched_image_ids, 'b i -> b 1 i 1') == rearrange(batched_image_ids, 'b j -> b 1 1 j')
        attn_mask = attn_mask & rearrange(key_pad_mask, 'b j -> b 1 1 j')

        # combine patched images as well as the patched width / height positions for 2d positional embedding

        patches = pad_sequence(batched_sequences)
        patch_positions = pad_sequence(batched_positions)

        # need to know how many images for final attention pooling

        num_images = torch.tensor(num_images, device = device, dtype = torch.long)

        # to patch embeddings

        x = self.to_patch_embedding(patches)

        # factorized 2d absolute positional embedding

        h_indices, w_indices = patch_positions.unbind(dim = -1)

        h_pos = self.pos_embed_height[h_indices]
        w_pos = self.pos_embed_width[w_indices]

        x = x + h_pos + w_pos

        # embedding dropout

        x = self.dropout(x)

        # attention

        x = self.transformer(x, attn_mask = attn_mask)

        # do attention pooling at the end, one query slot per image

        max_queries = num_images.amax().item()

        queries = repeat(self.attn_pool_queries, 'd -> b n d', n = max_queries, b = x.shape[0])

        # attention pool mask - each query attends only to the patches of its own image

        image_id_arange = arange(max_queries)

        attn_pool_mask = rearrange(image_id_arange, 'i -> i 1') == rearrange(batched_image_ids, 'b j -> b 1 j')

        attn_pool_mask = attn_pool_mask & rearrange(key_pad_mask, 'b j -> b 1 j')

        # queries for images that are absent from a given batch element would be fully masked and
        # come back as nan from scaled_dot_product_attention - let those rows attend freely instead,
        # since their outputs are discarded below anyway

        attn_pool_mask = attn_pool_mask | ~attn_pool_mask.any(dim = -1, keepdim = True)

        attn_pool_mask = rearrange(attn_pool_mask, 'b i j -> b 1 i j')

        # attention pool

        x = self.attn_pool(queries, context = x, attn_mask = attn_pool_mask) + queries

        x = rearrange(x, 'b n d -> (b n) d')

        # each batch element may not have the same number of images

        is_images = image_id_arange < rearrange(num_images, 'b -> b 1')
        is_images = rearrange(is_images, 'b n -> (b n)')

        x = x[is_images]

        # project out to logits

        x = self.to_latent(x)

        return self.mlp_head(x)
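
# a minimal smoke test sketch - the hyperparameters and image sizes below are illustrative only
if __name__ == '__main__':
    model = NaViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 512,
        depth = 4,
        heads = 8,
        mlp_dim = 1024,
        token_dropout_prob = 0.1          # drop 10% of the patch tokens during training
    )

    # variable resolution images, pre-packed into two sequences
    # (alternatively, pass a flat list of tensors with group_images = True)
    images = [
        [torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
        [torch.randn(3, 64, 256)]
    ]

    logits = model(images)
    print(logits.shape)                   # torch.Size([3, 1000]) - one row of logits per image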