mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-05-14 12:18:06 +00:00
Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6032a54b48 | ||
|
|
06a1f42924 | ||
|
|
6ae6a3ab64 | ||
|
|
827300beed | ||
|
|
a7c4e7f79f | ||
|
|
54ec3f2af5 | ||
|
|
9aa52cce49 | ||
|
|
4c89017444 | ||
|
|
580258d99e | ||
|
|
6f1caef987 | ||
|
|
fb5014f0ee | ||
|
|
0b7518ef45 | ||
|
|
077d8c188f | ||
|
|
5888f05300 | ||
|
|
d518e89573 | ||
|
|
dd6462d19b | ||
|
|
a1ee1daa1a | ||
|
|
3cff5e547a | ||
|
|
fdaf7f92b9 | ||
|
|
0ebd4edab9 | ||
|
|
aa49c2783a | ||
|
|
6aa0374313 |
40
README.md
40
README.md
@@ -49,7 +49,7 @@
|
||||
|
||||
## Vision Transformer - Pytorch
|
||||
|
||||
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
|
||||
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the [attention](https://www.youtube.com/watch?v=eMlx5fFNoYc) revolution.
|
||||
|
||||
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>.
|
||||
|
||||
@@ -1358,7 +1358,7 @@ learner = Dino(
|
||||
hidden_layer = 'to_latent', # hidden layer name or index, from which to extract the embedding
|
||||
projection_hidden_size = 256, # projector network hidden dimension
|
||||
projection_layers = 4, # number of layers in projection network
|
||||
num_classes_K = 65336, # output logits dimensions (referenced as K in paper)
|
||||
num_classes_K = 65536, # output logits dimensions (referenced as K in paper)
|
||||
student_temp = 0.9, # student temperature
|
||||
teacher_temp = 0.04, # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
|
||||
local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper
|
||||
@@ -2213,4 +2213,40 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{gopalakrishnan2025decouplingwhatwherepolar,
|
||||
title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
|
||||
author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
|
||||
year = {2025},
|
||||
eprint = {2509.10534},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2509.10534},
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{qiu2025gatedattentionlargelanguage,
|
||||
title = {Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free},
|
||||
author = {Zihan Qiu and Zekun Wang and Bo Zheng and Zeyu Huang and Kaiyue Wen and Songlin Yang and Rui Men and Le Yu and Fei Huang and Suozhi Huang and Dayiheng Liu and Jingren Zhou and Junyang Lin},
|
||||
year = {2025},
|
||||
eprint = {2505.06708},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CL},
|
||||
url = {https://arxiv.org/abs/2505.06708},
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{chen2026postlayernormbackstableexpressive,
|
||||
title = {Post-LayerNorm Is Back: Stable, ExpressivE, and Deep},
|
||||
author = {Chen Chen and Lai Wei},
|
||||
year = {2026},
|
||||
eprint = {2601.19895},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2601.19895},
|
||||
}
|
||||
```
|
||||
|
||||
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vit-pytorch"
|
||||
version = "1.15.5"
|
||||
version = "1.17.8"
|
||||
description = "Vision Transformer (ViT) - Pytorch"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
license = { file = "LICENSE" }
|
||||
|
||||
@@ -25,12 +25,12 @@ class DistillMixin:
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
|
||||
cls_tokens = repeat(self.cls_token, 'n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x += self.pos_embedding[:(n + 1)]
|
||||
|
||||
if distilling:
|
||||
distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
|
||||
distill_tokens = repeat(distill_token, 'n d -> b n d', b = b)
|
||||
x = torch.cat((x, distill_tokens), dim = 1)
|
||||
|
||||
x = self._attend(x)
|
||||
@@ -125,7 +125,7 @@ class DistillWrapper(Module):
|
||||
self.alpha = alpha
|
||||
self.hard = hard
|
||||
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, dim))
|
||||
|
||||
self.distill_mlp = nn.Sequential(
|
||||
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from functools import partial, lru_cache
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
@@ -9,7 +9,6 @@ from torch import nn, Tensor
|
||||
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
@@ -28,6 +27,12 @@ def pair(t):
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def posemb_grid(ph, pw, device):
|
||||
h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
|
||||
w_idx = torch.arange(pw, device=device).repeat(ph)
|
||||
return torch.stack([h_idx, w_idx], dim=-1)
|
||||
|
||||
# auto grouping images
|
||||
|
||||
def group_images_by_max_seq_len(
|
||||
@@ -117,8 +122,7 @@ class Attention(nn.Module):
|
||||
self.q_norm = RMSNorm(heads, dim_head)
|
||||
self.k_norm = RMSNorm(heads, dim_head)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.dropout_p = dropout
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
@@ -145,19 +149,22 @@ class Attention(nn.Module):
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2))
|
||||
# combine masks if both exist
|
||||
if exists(mask) or exists(attn_mask):
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
if exists(mask) and exists(attn_mask):
|
||||
attn_mask = mask & attn_mask
|
||||
elif exists(mask):
|
||||
attn_mask = mask
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
|
||||
out = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask = attn_mask,
|
||||
dropout_p = self.dropout_p if self.training else 0.,
|
||||
scale = 1. # RMSNorm already includes sqrt(dim) scaling
|
||||
)
|
||||
|
||||
if exists(attn_mask):
|
||||
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
@@ -281,42 +288,41 @@ class NaViT(nn.Module):
|
||||
for images in batched_images:
|
||||
num_images.append(len(images))
|
||||
|
||||
sequences = []
|
||||
positions = []
|
||||
image_ids = torch.empty((0,), device = device, dtype = torch.long)
|
||||
|
||||
for image_id, image in enumerate(images):
|
||||
assert image.ndim ==3 and image.shape[0] == c
|
||||
# compute patch dimensions for all images
|
||||
patch_dims = []
|
||||
for image in images:
|
||||
assert image.ndim == 3 and image.shape[0] == c
|
||||
image_dims = image.shape[-2:]
|
||||
assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
|
||||
patch_dims.append((image_dims[0] // p, image_dims[1] // p))
|
||||
|
||||
ph, pw = map(lambda dim: dim // p, image_dims)
|
||||
# extract patches for all images
|
||||
sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images]
|
||||
|
||||
pos = torch.stack(torch.meshgrid((
|
||||
arange(ph),
|
||||
arange(pw)
|
||||
), indexing = 'ij'), dim = -1)
|
||||
# compute positions - uses lru_cache to avoid redundant computation across forward passes
|
||||
positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims]
|
||||
|
||||
pos = rearrange(pos, 'h w c -> (h w) c')
|
||||
seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
|
||||
|
||||
seq_len = seq.shape[-2]
|
||||
|
||||
if has_token_dropout:
|
||||
# handle token dropout
|
||||
if has_token_dropout:
|
||||
for i, (seq, pos) in enumerate(zip(sequences, positions)):
|
||||
image_dims = images[i].shape[-2:]
|
||||
token_dropout = self.calc_token_dropout(*image_dims)
|
||||
seq_len = seq.shape[0]
|
||||
num_keep = max(1, int(seq_len * (1 - token_dropout)))
|
||||
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
|
||||
keep_indices = torch.randn((seq_len,), device=device).topk(num_keep, dim=-1).indices
|
||||
sequences[i] = seq[keep_indices]
|
||||
positions[i] = pos[keep_indices]
|
||||
|
||||
seq = seq[keep_indices]
|
||||
pos = pos[keep_indices]
|
||||
|
||||
image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
|
||||
sequences.append(seq)
|
||||
positions.append(pos)
|
||||
# build image_ids efficiently using repeat_interleave
|
||||
patch_counts = [seq.shape[0] for seq in sequences]
|
||||
image_ids = torch.repeat_interleave(
|
||||
arange(len(images)),
|
||||
torch.tensor(patch_counts, device=device)
|
||||
)
|
||||
|
||||
batched_image_ids.append(image_ids)
|
||||
batched_sequences.append(torch.cat(sequences, dim = 0))
|
||||
batched_positions.append(torch.cat(positions, dim = 0))
|
||||
batched_sequences.append(torch.cat(sequences, dim=0))
|
||||
batched_positions.append(torch.cat(positions, dim=0))
|
||||
|
||||
# derive key padding mask
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ class NaViT(Module):
|
||||
|
||||
self.channels = channels
|
||||
self.patch_size = patch_size
|
||||
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
|
||||
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c pf p1 p2)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
nn.LayerNorm(patch_dim),
|
||||
|
||||
@@ -146,7 +146,7 @@ class SimpleViT(Module):
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
@@ -103,7 +103,7 @@ class SimpleViT(nn.Module):
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
@@ -120,6 +120,12 @@ class Attention(Module):
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
|
||||
|
||||
self.to_out_gates = nn.Sequential(
|
||||
nn.Linear(dim, heads),
|
||||
Rearrange('b ... h -> b h ... 1'),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
@@ -150,6 +156,9 @@ class Attention(Module):
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
|
||||
out = out * self.to_out_gates(x)
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
@@ -215,7 +224,8 @@ class AST(Module):
|
||||
spec_hop_length = None,
|
||||
spec_pad = 0,
|
||||
spec_center = True,
|
||||
spec_pad_mode = 'reflect'
|
||||
spec_pad_mode = 'reflect',
|
||||
num_register_tokens = 4
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -256,8 +266,11 @@ class AST(Module):
|
||||
)
|
||||
|
||||
self.final_norm = nn.LayerNorm(dim)
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
raw_audio_or_spec, # (b t) | (b f t)
|
||||
@@ -296,6 +309,12 @@ class AST(Module):
|
||||
|
||||
tokens = rearrange(tokens, 'b ... c -> b (...) c')
|
||||
|
||||
# register tokens
|
||||
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
|
||||
tokens, packed_shape = pack((register_tokens, tokens), 'b * d')
|
||||
|
||||
# attention
|
||||
|
||||
attended, hiddens = self.transformer(tokens, return_hiddens = True)
|
||||
@@ -307,6 +326,8 @@ class AST(Module):
|
||||
if return_hiddens:
|
||||
return normed, stack(hiddens)
|
||||
|
||||
register_tokens, normed = unpack(normed, packed_shape, 'b * d')
|
||||
|
||||
pooled = reduce(normed, 'b n d -> b d', 'mean')
|
||||
|
||||
maybe_logits = self.mlp_head(pooled)
|
||||
@@ -384,7 +405,7 @@ class ViT(Module):
|
||||
if return_hiddens:
|
||||
return x, stack(hiddens)
|
||||
|
||||
cls_tokens, x, register_tokens = unpack(x, packed_shape, 'b * d')
|
||||
register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
|
||||
|
||||
@@ -409,7 +430,8 @@ class VAAT(Module):
|
||||
dim_head,
|
||||
dim_action,
|
||||
mlp_dim,
|
||||
num_views = None,
|
||||
num_image_views = None,
|
||||
num_audio_views = None,
|
||||
num_tasks = None,
|
||||
dim_extra_token = None,
|
||||
num_register_tokens = 4,
|
||||
@@ -450,6 +472,8 @@ class VAAT(Module):
|
||||
|
||||
ast_dim = ast.dim
|
||||
|
||||
self.ast_accept_spec = ast.accept_spec
|
||||
|
||||
assert ast.depth == depth or exists(ast_layer_indices), f'if the VAAT depth is not equal to the AST depth, you must pass in the indices from the AST to be layered to the VAAT in order from bottom to top'
|
||||
|
||||
ast_layer_indices = default(ast_layer_indices, tuple(range(depth)))
|
||||
@@ -468,7 +492,9 @@ class VAAT(Module):
|
||||
|
||||
# maybe view embeddings
|
||||
|
||||
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
|
||||
self.image_view_emb = nn.Parameter(torch.randn(num_image_views, vit_dim) * 1e-2) if exists(num_image_views) and num_image_views > 1 else None
|
||||
|
||||
self.audio_view_emb = nn.Parameter(torch.randn(num_audio_views, ast_dim) * 1e-2) if exists(num_audio_views) and num_audio_views > 1 else None
|
||||
|
||||
# handle maybe task conditioning
|
||||
|
||||
@@ -511,12 +537,12 @@ class VAAT(Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
|
||||
audio_or_spec, # (b t) | (b f t) - batch, audio len | batch, spec freq, time
|
||||
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
|
||||
audio_or_spec, # (b v? t) | (b v?f t) - batch, audio len | batch, spec freq, time
|
||||
*,
|
||||
extra = None, # (b d) - batch, dim extra
|
||||
extra = None, # (b d) - batch, dim extra
|
||||
tasks = None, # (b)
|
||||
actions = None, # (b k d) - batch, action chunk length, action dimension
|
||||
actions = None, # (b k d) - batch, action chunk length, action dimension
|
||||
return_hiddens = False,
|
||||
freeze_vit = False,
|
||||
freeze_ast = False
|
||||
@@ -539,11 +565,26 @@ class VAAT(Module):
|
||||
|
||||
assert video_or_image.shape[3] == self.time_seq_len
|
||||
|
||||
# audio shapes - adding view if impliciy to be 1
|
||||
|
||||
if audio_or_spec.ndim == 2 and not self.ast_accept_spec:
|
||||
audio_or_spec = rearrange(audio_or_spec, 'b t -> b 1 t')
|
||||
|
||||
elif audio_or_spec.ndim == 3 and self.ast_accept_spec:
|
||||
audio_or_spec = rearrange(audio_or_spec, 'b f t -> b 1 f t')
|
||||
|
||||
# to images
|
||||
|
||||
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
|
||||
|
||||
images, packed_shape = pack([images], '* c h w')
|
||||
images, image_packed_shape = pack([images], '* c h w')
|
||||
|
||||
# to audio
|
||||
|
||||
if self.ast_accept_spec:
|
||||
audio_or_spec, audio_packed_shape = pack([audio_or_spec], '* f t')
|
||||
else:
|
||||
audio_or_spec, audio_packed_shape = pack([audio_or_spec], '* t')
|
||||
|
||||
# get representation trajectory from vit
|
||||
|
||||
@@ -558,9 +599,9 @@ class VAAT(Module):
|
||||
|
||||
hiddens = hiddens[self.vit_layer_indices]
|
||||
|
||||
# pack temporarily for embedding
|
||||
# unpack temporarily for embedding
|
||||
|
||||
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers
|
||||
hiddens, = unpack(hiddens, image_packed_shape, 'l * n d') # l for layers
|
||||
|
||||
# maybe add time embeddings
|
||||
|
||||
@@ -570,11 +611,11 @@ class VAAT(Module):
|
||||
|
||||
# maybe view embeddings
|
||||
|
||||
if exists(self.view_emb):
|
||||
assert self.view_emb.shape[0] == hiddens.shape[2]
|
||||
if exists(self.image_view_emb):
|
||||
assert self.image_view_emb.shape[0] == hiddens.shape[2]
|
||||
|
||||
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
|
||||
hiddens = hiddens + view_emb
|
||||
image_view_emb = rearrange(self.image_view_emb, 'v d -> v 1 1 d')
|
||||
hiddens = hiddens + image_view_emb
|
||||
|
||||
# get representation trajectory from ast
|
||||
|
||||
@@ -589,6 +630,18 @@ class VAAT(Module):
|
||||
|
||||
audio_hiddens = audio_hiddens[self.ast_layer_indices]
|
||||
|
||||
# unpack audio temporarily for embedding
|
||||
|
||||
audio_hiddens, = unpack(audio_hiddens, audio_packed_shape, 'l * n d') # l for layers
|
||||
|
||||
# maybe audio view embeddings
|
||||
|
||||
if exists(self.audio_view_emb):
|
||||
assert self.audio_view_emb.shape[0] == audio_hiddens.shape[2]
|
||||
|
||||
audio_view_emb = rearrange(self.audio_view_emb, 'v d -> v 1 1 d')
|
||||
audio_hiddens = audio_hiddens + audio_view_emb
|
||||
|
||||
# maybe tasks
|
||||
|
||||
if exists(tasks):
|
||||
@@ -600,7 +653,7 @@ class VAAT(Module):
|
||||
|
||||
image_context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
|
||||
|
||||
audio_context = audio_hiddens # eventually handle views (stereo and beyond)
|
||||
audio_context = rearrange(audio_hiddens, 'l b v n d -> l b (v n) d')
|
||||
|
||||
# get main action tokens and maybe append extra
|
||||
|
||||
@@ -691,7 +744,7 @@ if __name__ == '__main__':
|
||||
mlp_dim = 384 * 4
|
||||
)
|
||||
|
||||
vat = VAAT(
|
||||
vaat = VAAT(
|
||||
vit,
|
||||
ast,
|
||||
dim = 512,
|
||||
@@ -702,7 +755,8 @@ if __name__ == '__main__':
|
||||
dim_action = 20,
|
||||
action_chunk_len = 7,
|
||||
time_seq_len = 4,
|
||||
num_views = 2,
|
||||
num_image_views = 2,
|
||||
num_audio_views = 2,
|
||||
num_tasks = 4,
|
||||
add_self_attn = True,
|
||||
dim_extra_token = 33, # extra token with some variable dimension
|
||||
@@ -715,18 +769,18 @@ if __name__ == '__main__':
|
||||
)
|
||||
|
||||
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
|
||||
audio = torch.randn(2, 14_100 * 5)
|
||||
audio = torch.randn(2, 2, 14_100 * 5)
|
||||
|
||||
tasks = torch.randint(0, 4, (2,))
|
||||
extra = torch.randn(2, 33) # extra internal state
|
||||
|
||||
actions = torch.randn(2, 7, 20) # actions for learning
|
||||
|
||||
loss = vat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss = vaat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss.backward()
|
||||
|
||||
# after much training
|
||||
|
||||
pred_actions, hiddens = vat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
|
||||
pred_actions, hiddens = vaat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
|
||||
|
||||
assert pred_actions.shape == (2, 7, 20)
|
||||
|
||||
@@ -92,6 +92,12 @@ class Attention(Module):
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
|
||||
|
||||
self.to_out_gates = nn.Sequential(
|
||||
nn.Linear(dim, heads),
|
||||
Rearrange('b ... h -> b h ... 1'),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
@@ -122,6 +128,8 @@ class Attention(Module):
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = out * self.to_out_gates(x) # https://arxiv.org/abs/2505.06708
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
@@ -237,7 +245,7 @@ class ViT(Module):
|
||||
if return_hiddens:
|
||||
return x, stack(hiddens)
|
||||
|
||||
cls_tokens, x, register_tokens = unpack(x, packed_shape, 'b * d')
|
||||
register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
|
||||
|
||||
|
||||
487
vit_pytorch/vat_siglip.py
Normal file
487
vit_pytorch/vat_siglip.py
Normal file
@@ -0,0 +1,487 @@
|
||||
from __future__ import annotations
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, cat, stack, tensor, einsum
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# attention
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_context = None,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
norm_eps = 1e-6,
|
||||
gate_attn = False
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm = nn.LayerNorm(dim, eps = norm_eps)
|
||||
|
||||
self.is_cross_attn = exists(dim_context)
|
||||
dim_context = default(dim_context, dim)
|
||||
self.norm_context = nn.LayerNorm(dim_context, eps = norm_eps) if self.is_cross_attn else None
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim)
|
||||
self.to_kv = nn.Linear(dim_context, inner_dim * 2)
|
||||
|
||||
self.to_out_gates = nn.Sequential(
|
||||
nn.Linear(dim, heads),
|
||||
Rearrange('b ... h -> b h ... 1'),
|
||||
nn.Sigmoid()
|
||||
) if gate_attn else None
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context = None):
|
||||
x = self.norm(x)
|
||||
|
||||
if self.is_cross_attn:
|
||||
assert exists(context)
|
||||
context = self.norm_context(context)
|
||||
else:
|
||||
context = x
|
||||
|
||||
q = self.to_q(x)
|
||||
k, v = self.to_kv(context).chunk(2, dim = -1)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
|
||||
|
||||
sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
attn = sim.softmax(dim = -1)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
|
||||
if exists(self.to_out_gates):
|
||||
out = out * self.to_out_gates(x) # https://arxiv.org/abs/2505.06708
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
def FeedForward(
|
||||
dim,
|
||||
dim_inner,
|
||||
norm_eps = 1e-6
|
||||
):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim, eps = norm_eps),
|
||||
nn.Linear(dim, dim_inner),
|
||||
nn.GELU(approximate = 'tanh'),
|
||||
nn.Linear(dim_inner, dim)
|
||||
)
|
||||
|
||||
class SigLIP(Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_size = 224,
|
||||
patch_size = 14,
|
||||
dim = 1152,
|
||||
depth = 27,
|
||||
heads = 16,
|
||||
mlp_dim = 4304,
|
||||
norm_eps = 1e-6
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
dim_head = dim // heads
|
||||
|
||||
self.to_patch_embed = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_size * patch_size * 3, dim)
|
||||
)
|
||||
|
||||
self.pos_embed = nn.Parameter(torch.randn(num_patches, dim))
|
||||
|
||||
self.layers = ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, norm_eps = norm_eps),
|
||||
FeedForward(dim = dim, dim_inner = mlp_dim, norm_eps = norm_eps)
|
||||
]))
|
||||
|
||||
self.norm = nn.LayerNorm(dim, eps = norm_eps)
|
||||
|
||||
def forward(self, x, return_hiddens = False):
|
||||
x = self.to_patch_embed(x)
|
||||
num_patches = x.shape[1]
|
||||
|
||||
x = x + self.pos_embed[:num_patches]
|
||||
|
||||
hiddens = []
|
||||
|
||||
for attn, ff in self.layers:
|
||||
hiddens.append(x)
|
||||
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
out = self.norm(x)
|
||||
|
||||
if not return_hiddens:
|
||||
return out
|
||||
|
||||
return out, stack(hiddens)
|
||||
|
||||
class FiLM(Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
proj = nn.Linear(dim, dim * 2)
|
||||
|
||||
self.to_gamma_beta = nn.Sequential(
|
||||
proj,
|
||||
Rearrange('b (two d) -> two b 1 d', two = 2)
|
||||
)
|
||||
|
||||
nn.init.zeros_(proj.weight)
|
||||
nn.init.zeros_(proj.bias)
|
||||
|
||||
def forward(self, tokens, cond):
|
||||
gamma, beta = self.to_gamma_beta(cond)
|
||||
return tokens * gamma + beta
|
||||
|
||||
class SigLIPVAT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim = 512,
|
||||
depth = 27,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
dim_action = 32,
|
||||
mlp_dim = 2048,
|
||||
num_views = 1,
|
||||
num_tasks = None,
|
||||
dim_extra_token = None,
|
||||
num_register_tokens = 4,
|
||||
action_chunk_len = 50,
|
||||
time_seq_len = 1,
|
||||
dropout = 0.,
|
||||
add_self_attn = True,
|
||||
self_attn_heads = 4,
|
||||
self_attn_dim_head = 32,
|
||||
vit_layer_indices: tuple[int, ...] | None = None,
|
||||
siglip_image_size = 224,
|
||||
siglip_patch_size = 14,
|
||||
siglip_dim = 1152,
|
||||
siglip_depth = 27,
|
||||
siglip_heads = 16,
|
||||
siglip_mlp_dim = 4304,
|
||||
siglip_norm_eps = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.vit = SigLIP(
|
||||
image_size = siglip_image_size,
|
||||
patch_size = siglip_patch_size,
|
||||
dim = siglip_dim,
|
||||
depth = siglip_depth,
|
||||
heads = siglip_heads,
|
||||
mlp_dim = siglip_mlp_dim,
|
||||
norm_eps = siglip_norm_eps
|
||||
)
|
||||
|
||||
vit_dim = siglip_dim
|
||||
self.vit_dim = vit_dim
|
||||
|
||||
vit_layer_indices = default(vit_layer_indices, tuple(range(depth)))
|
||||
assert len(vit_layer_indices) == depth, f'number of vit layer indices {len(vit_layer_indices)} does not much the VAT depth {depth}'
|
||||
self.register_buffer('layer_indices', tensor(vit_layer_indices), persistent = False)
|
||||
|
||||
# handle maybe multiple frames
|
||||
|
||||
is_video = time_seq_len > 1
|
||||
self.is_video = is_video
|
||||
self.time_seq_len = time_seq_len
|
||||
self.time_pos_emb = nn.Parameter(torch.randn(time_seq_len, vit_dim) * 1e-2) if is_video else None
|
||||
|
||||
# maybe view embeddings
|
||||
|
||||
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
|
||||
|
||||
# handle maybe task conditioning
|
||||
|
||||
self.has_tasks = exists(num_tasks)
|
||||
if self.has_tasks:
|
||||
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
|
||||
|
||||
# register tokens
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
|
||||
|
||||
# to action tokens
|
||||
|
||||
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
|
||||
|
||||
self.layers = ModuleList([])
|
||||
for _ in range(depth):
|
||||
maybe_film = FiLM(dim = dim) if self.has_tasks else None
|
||||
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
|
||||
|
||||
self.layers.append(ModuleList([
|
||||
maybe_film,
|
||||
maybe_self_attn,
|
||||
Attention(dim = dim, dim_context = vit_dim, heads = heads, dim_head = dim_head, dropout = dropout, gate_attn = True),
|
||||
FeedForward(dim = dim, dim_inner = mlp_dim)
|
||||
]))
|
||||
|
||||
self.final_norm = nn.LayerNorm(dim)
|
||||
self.to_pred_action = nn.Linear(dim, dim_action, bias = False)
|
||||
|
||||
# handle the extra token
|
||||
|
||||
self.accept_extra_token = exists(dim_extra_token)
|
||||
if exists(dim_extra_token):
|
||||
self.to_extra_token = nn.Linear(dim_extra_token, dim)
|
||||
|
||||
def load_siglip(
|
||||
self,
|
||||
repo_id = 'google/siglip-so400m-patch14-224',
|
||||
folder = 'checkpoints/siglip'
|
||||
):
|
||||
folder = Path(folder)
|
||||
if not folder.exists():
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download(
|
||||
repo_id = repo_id,
|
||||
local_dir = folder,
|
||||
allow_patterns = ['config.json', 'model.safetensors']
|
||||
)
|
||||
|
||||
from safetensors import safe_open
|
||||
weights_path = folder / 'model.safetensors'
|
||||
|
||||
# Auto-detect prefix based on keys
|
||||
with safe_open(weights_path, framework = 'pt') as f:
|
||||
keys = f.keys()
|
||||
|
||||
vi_p = ''
|
||||
if any(k.startswith('paligemma_with_expert.paligemma.model.vision_tower.vision_model') for k in keys):
|
||||
vi_p = 'paligemma_with_expert.paligemma.model.vision_tower.vision_model.'
|
||||
elif any(k.startswith('vision_model') for k in keys):
|
||||
vi_p = 'vision_model.'
|
||||
|
||||
pz_state = self.vit.state_dict()
|
||||
|
||||
def copy_weight_bias(pz_prefix, vi_prefix):
|
||||
pz_state[f'{pz_prefix}.weight'].copy_(f.get_tensor(f'{vi_prefix}.weight'))
|
||||
pz_state[f'{pz_prefix}.bias'].copy_(f.get_tensor(f'{vi_prefix}.bias'))
|
||||
|
||||
# patch embedding
|
||||
patch_weight = rearrange(f.get_tensor(f'{vi_p}embeddings.patch_embedding.weight'), 'd c h w -> d (h w c)')
|
||||
pz_state['to_patch_embed.1.weight'].copy_(patch_weight)
|
||||
pz_state['to_patch_embed.1.bias'].copy_(f.get_tensor(f'{vi_p}embeddings.patch_embedding.bias'))
|
||||
|
||||
# position embedding
|
||||
pz_state['pos_embed'].copy_(f.get_tensor(f'{vi_p}embeddings.position_embedding.weight'))
|
||||
|
||||
# transformer layers
|
||||
for i in range(self.vit.depth):
|
||||
v_pi = f'{vi_p}encoder.layers.{i}'
|
||||
v_pz = f'layers.{i}'
|
||||
|
||||
# attention
|
||||
copy_weight_bias(f'{v_pz}.0.norm', f'{v_pi}.layer_norm1')
|
||||
copy_weight_bias(f'{v_pz}.0.to_q', f'{v_pi}.self_attn.q_proj')
|
||||
|
||||
vk, vv = [f.get_tensor(f'{v_pi}.self_attn.{x}_proj.weight') for x in ('k', 'v')]
|
||||
bk, bv = [f.get_tensor(f'{v_pi}.self_attn.{x}_proj.bias') for x in ('k', 'v')]
|
||||
|
||||
pz_state[f'{v_pz}.0.to_kv.weight'].copy_(cat((vk, vv), dim = 0))
|
||||
pz_state[f'{v_pz}.0.to_kv.bias'].copy_(cat((bk, bv), dim = 0))
|
||||
|
||||
copy_weight_bias(f'{v_pz}.0.to_out.0', f'{v_pi}.self_attn.out_proj')
|
||||
|
||||
# feedforward
|
||||
copy_weight_bias(f'{v_pz}.1.0', f'{v_pi}.layer_norm2')
|
||||
copy_weight_bias(f'{v_pz}.1.1', f'{v_pi}.mlp.fc1')
|
||||
copy_weight_bias(f'{v_pz}.1.3', f'{v_pi}.mlp.fc2')
|
||||
|
||||
# post-layernorm
|
||||
copy_weight_bias('norm', f'{vi_p}post_layernorm')
|
||||
|
||||
self.vit.load_state_dict(pz_state)
|
||||
|
||||
print(f'Successfully loaded SigLIP weights from {repo_id}')
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_or_image, # (b v? c t? h w)
|
||||
*,
|
||||
extra = None,
|
||||
tasks = None,
|
||||
actions = None,
|
||||
return_hiddens = False,
|
||||
freeze_vit = False
|
||||
):
|
||||
batch = video_or_image.shape[0]
|
||||
return_loss = exists(actions)
|
||||
|
||||
# handle some various input dimensions
|
||||
|
||||
if video_or_image.ndim == 4:
|
||||
video_or_image = rearrange(video_or_image, 'b 1 c h w')
|
||||
|
||||
if video_or_image.ndim == 5:
|
||||
video_or_image = rearrange(video_or_image, 'b v c h w -> b v c 1 h w')
|
||||
|
||||
assert video_or_image.shape[3] == self.time_seq_len
|
||||
|
||||
# to images
|
||||
|
||||
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
|
||||
images, packed_shape = pack([images], '* c h w')
|
||||
|
||||
# get representation trajectory from vit
|
||||
|
||||
vit_forward_context = torch.no_grad if freeze_vit else nullcontext
|
||||
|
||||
with vit_forward_context():
|
||||
embed, hiddens = self.vit(images, return_hiddens = True)
|
||||
|
||||
hiddens = cat((hiddens, embed[None, ...]))
|
||||
|
||||
# extract the hiddens needed for the action cross attention
|
||||
|
||||
hiddens = hiddens[self.layer_indices]
|
||||
|
||||
# pack temporarily for embedding
|
||||
|
||||
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers
|
||||
|
||||
# maybe add time embeddings
|
||||
|
||||
if self.is_video:
|
||||
time_pos_emb = rearrange(self.time_pos_emb, 't d -> t 1 d')
|
||||
hiddens = hiddens + time_pos_emb
|
||||
|
||||
# maybe view embeddings
|
||||
|
||||
if exists(self.view_emb):
|
||||
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
|
||||
hiddens = hiddens + view_emb
|
||||
|
||||
# maybe tasks
|
||||
|
||||
if exists(tasks):
|
||||
task_emb = self.task_emb[tasks]
|
||||
|
||||
# cross from actions to representation trajectory
|
||||
|
||||
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
|
||||
|
||||
# get main action tokens and maybe append extra
|
||||
|
||||
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
|
||||
|
||||
has_extra = exists(extra)
|
||||
if has_extra:
|
||||
extra_token = self.to_extra_token(extra)
|
||||
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
|
||||
|
||||
# register tokens
|
||||
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
|
||||
|
||||
# cross attention
|
||||
|
||||
vat_hiddens = [action_tokens]
|
||||
|
||||
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
|
||||
|
||||
if exists(tasks):
|
||||
action_tokens = maybe_film(action_tokens, task_emb)
|
||||
|
||||
action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
|
||||
|
||||
if exists(maybe_self_attn):
|
||||
action_tokens = maybe_self_attn(action_tokens) + action_tokens
|
||||
|
||||
action_tokens = ff(action_tokens) + action_tokens
|
||||
|
||||
vat_hiddens.append(action_tokens)
|
||||
|
||||
# unpack registers
|
||||
|
||||
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
|
||||
|
||||
# maybe unpack extra
|
||||
|
||||
if has_extra:
|
||||
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
|
||||
|
||||
# norm and prediction
|
||||
|
||||
action_tokens = self.final_norm(action_tokens)
|
||||
pred_action = self.to_pred_action(action_tokens)
|
||||
|
||||
if not return_loss:
|
||||
if not return_hiddens:
|
||||
return pred_action
|
||||
|
||||
return pred_action, stack(vat_hiddens)
|
||||
|
||||
assert pred_action.shape[1] == actions.shape[1]
|
||||
|
||||
return F.l1_loss(pred_action, actions)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
vat = SigLIPVAT(
|
||||
num_tasks = 4,
|
||||
dim_extra_token = 32,
|
||||
time_seq_len = 2,
|
||||
num_views = 2,
|
||||
depth = 4,
|
||||
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
|
||||
0, 1, 26, 27
|
||||
)
|
||||
)
|
||||
|
||||
vat.load_siglip() # load siglip weights from hf
|
||||
|
||||
# inputs
|
||||
|
||||
images = torch.randn(1, 2, 3, 2, 224, 224) # (b, v, c, t, h, w)
|
||||
tasks = torch.randint(0, 4, (1,))
|
||||
extra = torch.randn(1, 32)
|
||||
|
||||
actions = torch.randn(1, 50, 32) # actions for learning
|
||||
|
||||
loss = vat(images, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss.backward()
|
||||
|
||||
# after much training
|
||||
|
||||
pred_actions = vat(images, tasks = tasks, extra = extra)
|
||||
|
||||
assert pred_actions.shape == (1, 50, 32)
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
@@ -11,7 +12,7 @@ def pair(t):
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
@@ -26,7 +27,7 @@ class FeedForward(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
@@ -62,13 +63,14 @@ class Attention(nn.Module):
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
@@ -80,7 +82,7 @@ class Transformer(nn.Module):
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViT(nn.Module):
|
||||
class ViT(Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
@@ -90,7 +92,9 @@ class ViT(nn.Module):
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
num_cls_tokens = 1 if pool == 'cls' else 0
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
@@ -99,8 +103,9 @@ class ViT(nn.Module):
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
|
||||
self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
@@ -108,19 +113,25 @@ class ViT(nn.Module):
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
|
||||
|
||||
def forward(self, img):
|
||||
batch = img.shape[0]
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
|
||||
seq = x.shape[1]
|
||||
|
||||
x = x + self.pos_embedding[:seq]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
if self.mlp_head is None:
|
||||
return x
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
|
||||
@@ -89,7 +89,7 @@ class ViT(nn.Module):
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
353
vit_pytorch/vit_nd_pope.py
Normal file
353
vit_pytorch/vit_nd_pope.py
Normal file
@@ -0,0 +1,353 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import pi, nn, arange, cat, stack, Tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.amp import autocast
|
||||
|
||||
from einops import rearrange, repeat, reduce, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1, p = 2)
|
||||
|
||||
def join(arr, delimiter = ' '):
|
||||
return delimiter.join(arr)
|
||||
|
||||
def ensure_tuple(t, length):
|
||||
if isinstance(t, (tuple, list)):
|
||||
assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
|
||||
return tuple(t)
|
||||
|
||||
return (t,) * length
|
||||
|
||||
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
||||
# https://jerryxio.ng/posts/nd-rope/
|
||||
|
||||
# but using polar version instead
|
||||
# Gopalakrishnan et al. https://arxiv.org/abs/2509.10534
|
||||
|
||||
def _phi(m: int) -> float:
|
||||
x = 2.0
|
||||
for _ in range(10):
|
||||
x = (1 + x) ** (1.0 / (m + 1.0))
|
||||
return x
|
||||
|
||||
def make_directions(n: int, d: int) -> Tensor:
|
||||
g = _phi(d)
|
||||
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
|
||||
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
|
||||
z = torch.fmod(i * alpha, 1.0)
|
||||
directions = torch.erfinv(2.0 * z - 1.0)
|
||||
directions = l2norm(directions)
|
||||
return directions.float()
|
||||
|
||||
class GoldenGatePoPENd(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_pos: int,
|
||||
heads: int,
|
||||
dim_head: int,
|
||||
min_freq: float = 1.0,
|
||||
max_freq: float = 10000.0,
|
||||
p_zero_freqs: float = 0.0, # proportion of frequencies set to 0
|
||||
init_learned_bias_uniform = False
|
||||
):
|
||||
super().__init__()
|
||||
n_freqs = dim_head
|
||||
n_zero_freqs = round(p_zero_freqs * n_freqs)
|
||||
|
||||
omega = cat((
|
||||
torch.zeros(n_zero_freqs),
|
||||
min_freq * (max_freq / min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
|
||||
))
|
||||
|
||||
directions = rearrange(
|
||||
make_directions(heads * n_freqs, dim_pos),
|
||||
'(h f) p -> h f p',
|
||||
h = heads
|
||||
)
|
||||
|
||||
omega_expanded = rearrange(omega, 'f -> f 1')
|
||||
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
|
||||
|
||||
self.learned_bias = nn.Parameter(torch.zeros(heads, dim_head))
|
||||
|
||||
if init_learned_bias_uniform:
|
||||
self.learned_bias.uniform_(-2. * pi, 0.)
|
||||
|
||||
@autocast('cuda', enabled = False)
|
||||
def forward(self, pos):
|
||||
|
||||
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
|
||||
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
|
||||
|
||||
# compute theta for each (batch, head, seq, freq)
|
||||
|
||||
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
|
||||
|
||||
bias = self.learned_bias.clamp(-2. * pi, 0.)
|
||||
bias = rearrange(bias, 'h d -> h 1 d')
|
||||
|
||||
return theta, bias
|
||||
|
||||
@autocast('cuda', enabled = False)
|
||||
def apply_polar_pos_emb(t, freqs):
|
||||
orig_dtype = t.dtype
|
||||
|
||||
t = t.float()
|
||||
t = F.softplus(t)
|
||||
|
||||
out = cat((t * freqs.cos(), t * freqs.sin()), dim = -1)
|
||||
|
||||
return out.type(orig_dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x, polar_pos_emb = None):
|
||||
x = self.norm(x)
|
||||
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
if exists(polar_pos_emb):
|
||||
freqs, bias = polar_pos_emb
|
||||
q = apply_polar_pos_emb(q, freqs)
|
||||
k = apply_polar_pos_emb(k, freqs + bias)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., polar_emb = None):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.polar_emb = polar_emb
|
||||
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x, pos = None):
|
||||
|
||||
# pope embedding
|
||||
|
||||
polar_pos_emb = None
|
||||
if exists(pos) and exists(self.polar_emb):
|
||||
polar_pos_emb = self.polar_emb(pos)
|
||||
|
||||
# transformer layers
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, polar_pos_emb) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViTND(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ndim: int,
|
||||
input_shape: int | tuple[int, ...],
|
||||
patch_size: int | tuple[int, ...],
|
||||
num_classes: int,
|
||||
dim: int,
|
||||
depth: int,
|
||||
heads: int,
|
||||
mlp_dim: int,
|
||||
channels: int = 3,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.,
|
||||
emb_dropout: float = 0.,
|
||||
pope_min_freq: float = 1.0,
|
||||
pope_max_freq: float = 10000.0,
|
||||
pope_p_zero_freqs: float = 0.0,
|
||||
pope_init_learned_bias_uniform = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'
|
||||
|
||||
self.ndim = ndim
|
||||
|
||||
input_shape = ensure_tuple(input_shape, ndim)
|
||||
patch_size = ensure_tuple(patch_size, ndim)
|
||||
|
||||
for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
|
||||
assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'
|
||||
|
||||
num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
|
||||
num_patches = 1
|
||||
for n in num_patches_per_dim:
|
||||
num_patches *= n
|
||||
|
||||
patch_dim = channels
|
||||
for p in patch_size:
|
||||
patch_dim *= p
|
||||
|
||||
dim_names = 'fghijkl'[:ndim]
|
||||
|
||||
input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
|
||||
patch_dims = [f'p{i}' for i in range(ndim)]
|
||||
|
||||
input_pattern = f'b c {join(input_dims)}'
|
||||
output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
|
||||
rearrange_str = f'{input_pattern} -> {output_pattern}'
|
||||
|
||||
rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange(rearrange_str, **rearrange_kwargs),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
# golden gate pope
|
||||
|
||||
self.polar_emb = GoldenGatePoPENd(
|
||||
dim_pos = ndim,
|
||||
heads = heads,
|
||||
dim_head = dim_head,
|
||||
min_freq = pope_min_freq,
|
||||
max_freq = pope_max_freq,
|
||||
p_zero_freqs = pope_p_zero_freqs,
|
||||
init_learned_bias_uniform = pope_init_learned_bias_uniform
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, polar_emb = self.polar_emb)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def muon_parameters(self):
|
||||
params = []
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, Attention):
|
||||
params.extend([
|
||||
m.to_v.weight,
|
||||
m.to_out[0].weight
|
||||
])
|
||||
elif isinstance(m, FeedForward):
|
||||
params.extend([
|
||||
m.net[1].weight,
|
||||
m.net[-2].weight
|
||||
])
|
||||
|
||||
return params
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_embed = False
|
||||
):
|
||||
x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
|
||||
|
||||
batch, *spatial_dims, _, device = *x.shape, x.device
|
||||
|
||||
# Generate position coordinates
|
||||
|
||||
grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
|
||||
grid = torch.meshgrid(*grids, indexing = 'ij')
|
||||
pos = stack(grid, dim = -1) # (*spatial_dims, ndim)
|
||||
|
||||
# flatten spatial dimensions for attention with nd rotary
|
||||
|
||||
pos = repeat(pos, '... p -> b (...) p', b = batch)
|
||||
x, packed_shape = pack([x], 'b * d')
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
embed = self.transformer(x, pos)
|
||||
|
||||
# return the embed with reconstituted patch shape
|
||||
|
||||
if return_embed:
|
||||
embed, = unpack(embed, packed_shape, 'b * d')
|
||||
return embed
|
||||
|
||||
# pooling to logits
|
||||
|
||||
pooled = reduce(embed, 'b n d -> b d', 'mean')
|
||||
|
||||
pooled = self.to_latent(pooled)
|
||||
return self.mlp_head(pooled)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model = ViTND(
|
||||
ndim = 5,
|
||||
input_shape = (4, 8, 16, 32, 64),
|
||||
patch_size = (2, 2, 4, 4, 8),
|
||||
num_classes = 1000,
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
channels = 3,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
data = torch.randn(3, 3, 4, 8, 16, 32, 64)
|
||||
|
||||
logits = model(data)
|
||||
|
||||
embed = model(data, return_embed = True)
|
||||
assert embed.shape == (3, 2, 4, 4, 8, 8, 512)
|
||||
217
vit_pytorch/vit_with_keel_post_ln.py
Normal file
217
vit_pytorch/vit_with_keel_post_ln.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# functions
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
dim_head,
|
||||
mlp_dim,
|
||||
dropout = 0.,
|
||||
keel_residual_scale = None
|
||||
):
|
||||
super().__init__()
|
||||
assert depth > 1
|
||||
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.extend([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
])
|
||||
|
||||
num_layers = depth * 2
|
||||
self.keel_residual_scale = default(keel_residual_scale, num_layers)
|
||||
|
||||
self.post_norms = ModuleList([nn.LayerNorm(dim, bias = False) for _ in range(num_layers - 1)])
|
||||
|
||||
def forward(self, x):
|
||||
residual_scale = self.keel_residual_scale
|
||||
|
||||
for layer_ind, layer in enumerate(self.layers):
|
||||
first_layer = layer_ind == 0
|
||||
|
||||
residual = x
|
||||
|
||||
out = layer(x)
|
||||
|
||||
if first_layer:
|
||||
x = out + residual
|
||||
continue
|
||||
|
||||
post_norm = self.post_norms[layer_ind - 1]
|
||||
|
||||
x = post_norm(out + residual * residual_scale)
|
||||
|
||||
return x
|
||||
|
||||
class ViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
pool = 'cls',
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
keel_residual_scale = None
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
num_cls_tokens = 1 if pool == 'cls' else 0
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
|
||||
self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
dim_head,
|
||||
mlp_dim,
|
||||
dropout,
|
||||
keel_residual_scale = keel_residual_scale
|
||||
)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
|
||||
|
||||
def forward(self, img):
|
||||
batch = img.shape[0]
|
||||
x = self.to_patch_embedding(img)
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
|
||||
seq = x.shape[1]
|
||||
|
||||
x = x + self.pos_embedding[:seq]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
if not exists(self.mlp_head):
|
||||
return x
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
if __name__ == '__main__':
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
preds = v(img)
|
||||
|
||||
assert preds.shape == (1, 1000)
|
||||
@@ -1,5 +1,10 @@
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import nn, cat
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
|
||||
from einops import rearrange, repeat, reduce
|
||||
from einops.layers.torch import Rearrange
|
||||
@@ -9,12 +14,15 @@ from einops.layers.torch import Rearrange
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
@@ -28,9 +36,11 @@ class FeedForward(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_flash_attn = True):
|
||||
super().__init__()
|
||||
self.use_flash_attn = use_flash_attn
|
||||
self.dropout_p = dropout
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
@@ -48,61 +58,100 @@ class Attention(nn.Module):
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
def flash_attn(self, q, k, v, mask = None):
|
||||
|
||||
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
|
||||
|
||||
out = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask = mask,
|
||||
dropout_p = self.dropout_p,
|
||||
is_causal = False,
|
||||
scale = self.scale
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def forward(self, x, mask = None):
|
||||
batch, seq, _ = x.shape
|
||||
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
if self.use_flash_attn:
|
||||
out = self.flash_attn(q, k, v, mask = mask)
|
||||
|
||||
else:
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
|
||||
super().__init__()
|
||||
self.use_flash_attn = use_flash_attn
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
self.layers = ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
def forward(self, x):
|
||||
|
||||
def forward(self, x, mask = None):
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = attn(x, mask = mask) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class FactorizedTransformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
class FactorizedTransformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
|
||||
super().__init__()
|
||||
self.use_flash_attn = use_flash_attn
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
b, f, n, _ = x.shape
|
||||
def forward(self, x, mask = None):
|
||||
batch, frames, seq, _ = x.shape
|
||||
|
||||
if exists(mask):
|
||||
mask = repeat(mask, 'b ... -> (b space) ...', space = x.shape[2])
|
||||
|
||||
for spatial_attn, temporal_attn, ff in self.layers:
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
x = spatial_attn(x) + x
|
||||
x = rearrange(x, '(b f) n d -> (b n) f d', b=b, f=f)
|
||||
x = temporal_attn(x) + x
|
||||
x = rearrange(x, '(b f) n d -> (b n) f d', b = batch, f = frames)
|
||||
x = temporal_attn(x, mask = mask) + x
|
||||
x = ff(x) + x
|
||||
x = rearrange(x, '(b n) f d -> b f n d', b=b, n=n)
|
||||
x = rearrange(x, '(b n) f d -> b f n d', b = batch, n = seq)
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViT(nn.Module):
|
||||
class ViViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -122,13 +171,14 @@ class ViT(nn.Module):
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
variant = 'factorized_encoder',
|
||||
use_flash_attn: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(image_patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
|
||||
assert divisible_by(image_height, patch_height) and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert divisible_by(frames, frame_patch_size), 'Frames must be divisible by frame patch size'
|
||||
assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'
|
||||
|
||||
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
@@ -138,10 +188,12 @@ class ViT(nn.Module):
|
||||
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.frame_patch_size = frame_patch_size
|
||||
|
||||
self.global_average_pool = pool == 'mean'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
@@ -154,11 +206,11 @@ class ViT(nn.Module):
|
||||
|
||||
if variant == 'factorized_encoder':
|
||||
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
|
||||
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
|
||||
elif variant == 'factorized_self_attention':
|
||||
assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
|
||||
self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
@@ -166,25 +218,36 @@ class ViT(nn.Module):
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
self.variant = variant
|
||||
|
||||
def forward(self, video):
|
||||
x = self.to_patch_embedding(video)
|
||||
b, f, n, _ = x.shape
|
||||
def forward(self, video, mask = None):
|
||||
device = video.device
|
||||
|
||||
x = x + self.pos_embedding[:, :f, :n]
|
||||
x = self.to_patch_embedding(video)
|
||||
batch, frames, seq, _ = x.shape
|
||||
|
||||
x = x + self.pos_embedding[:, :frames, :seq]
|
||||
|
||||
if exists(self.spatial_cls_token):
|
||||
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
|
||||
x = torch.cat((spatial_cls_tokens, x), dim = 2)
|
||||
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = batch, f = frames)
|
||||
x = cat((spatial_cls_tokens, x), dim = 2)
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
# maybe temporal mask
|
||||
|
||||
temporal_mask = None
|
||||
|
||||
if exists(mask):
|
||||
temporal_mask = reduce(mask, 'b (f patch) -> b f', 'all', patch = self.frame_patch_size)
|
||||
|
||||
# the two variants
|
||||
|
||||
if self.variant == 'factorized_encoder':
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
|
||||
# attend across space
|
||||
|
||||
x = self.spatial_transformer(x)
|
||||
x = rearrange(x, '(b f) n d -> b f n d', b = b)
|
||||
x = rearrange(x, '(b f) n d -> b f n d', b = batch)
|
||||
|
||||
# excise out the spatial cls tokens or average pool for temporal attention
|
||||
|
||||
@@ -193,22 +256,50 @@ class ViT(nn.Module):
|
||||
# append temporal CLS tokens
|
||||
|
||||
if exists(self.temporal_cls_token):
|
||||
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
|
||||
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = batch)
|
||||
|
||||
x = torch.cat((temporal_cls_tokens, x), dim = 1)
|
||||
|
||||
x = cat((temporal_cls_tokens, x), dim = 1)
|
||||
|
||||
if exists(temporal_mask):
|
||||
temporal_mask = F.pad(temporal_mask, (1, 0), value = True)
|
||||
|
||||
# attend across time
|
||||
|
||||
x = self.temporal_transformer(x)
|
||||
x = self.temporal_transformer(x, mask = temporal_mask)
|
||||
|
||||
# excise out temporal cls token or average pool
|
||||
|
||||
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
|
||||
|
||||
elif self.variant == 'factorized_self_attention':
|
||||
x = self.factorized_transformer(x)
|
||||
|
||||
x = self.factorized_transformer(x, mask = temporal_mask)
|
||||
|
||||
x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
# main
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
vivit = ViViT(
|
||||
dim = 512,
|
||||
spatial_depth = 2,
|
||||
temporal_depth = 2,
|
||||
heads = 4,
|
||||
mlp_dim = 2048,
|
||||
image_size = 256,
|
||||
image_patch_size = 16,
|
||||
frames = 8,
|
||||
frame_patch_size = 2,
|
||||
num_classes = 1000,
|
||||
variant = 'factorized_encoder',
|
||||
)
|
||||
|
||||
video = torch.randn(3, 3, 8, 256, 256)
|
||||
mask = torch.randint(0, 2, (3, 8)).bool()
|
||||
|
||||
logits = vivit(video, mask = None)
|
||||
assert logits.shape == (3, 1000)
|
||||
|
||||
Reference in New Issue
Block a user