Compare commits

...

10 Commits

Author SHA1 Message Date
lucidrains
6032a54b48 patch 2026-02-11 11:49:51 -08:00
Harikrishna KP
06a1f42924 Fix ViViT Transformer not passing use_flash_attn to Attention and duplicate mask reshape (#360)
Two related bugs in vivit.py:

1. Transformer.__init__ accepted use_flash_attn but never forwarded it to the
   Attention modules it creates. Since Attention defaults to use_flash_attn=True,
   setting use_flash_attn=False on ViViT had no effect on the factorized_encoder
   variant's spatial and temporal transformers.

2. Attention.forward reshaped the mask from 2D to 4D before the flash/non-flash
   branch (line 82), then attempted to reshape it again inside the non-flash
   branch (line 92). When the non-flash code path is actually reached with a
   mask, einops raises an error because the mask is already 4D.

   These bugs masked each other: bug #1 prevented bug #2 from triggering because
   the non-flash path was never taken even when requested.

Fix: pass use_flash_attn through to Attention in Transformer.__init__, and
remove the redundant second mask rearrange in the non-flash branch.
2026-02-11 11:49:31 -08:00
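For reference, a minimal self-contained sketch of the fixed pattern described in this commit (illustrative only: feedforward layers and the final norm are omitted, and names merely mirror the vivit.py diff further down this page — this is not the exact committed code):

```python
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_flash_attn = True):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.use_flash_attn = use_flash_attn
        self.dropout_p = dropout
        inner_dim = dim_head * heads
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, mask = None):
        x = self.norm(x)
        q, k, v = map(
            lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads),
            self.to_qkv(x).chunk(3, dim = -1)
        )

        # bug #2 fix: reshape the (batch, seq) key-padding mask to 4d exactly once,
        # before branching, so the non-flash path never tries to reshape it again
        if mask is not None:
            mask = rearrange(mask, 'b j -> b 1 1 j')

        if self.use_flash_attn:
            out = F.scaled_dot_product_attention(q, k, v, attn_mask = mask, dropout_p = self.dropout_p, scale = self.scale)
        else:
            dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
            if mask is not None:
                dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
            out = torch.matmul(dots.softmax(dim = -1), v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads = 8, dim_head = 64, dropout = 0., use_flash_attn = True):
        super().__init__()
        # bug #1 fix: actually forward use_flash_attn to the Attention modules
        self.layers = nn.ModuleList([
            Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn)
            for _ in range(depth)
        ])

    def forward(self, x, mask = None):
        for attn in self.layers:
            x = attn(x, mask = mask) + x
        return x

# both code paths now accept the same (batch, seq) boolean key-padding mask
x = torch.randn(2, 8, 512)
mask = torch.ones(2, 8, dtype = torch.bool)
mask[:, -2:] = False
for flag in (True, False):
    Transformer(512, depth = 1, use_flash_attn = flag)(x, mask = mask)
```
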
Phil Wang
6ae6a3ab64 cleanup 2026-02-04 13:29:40 -08:00
lucidrains
827300beed add vit with keel post ln, proposed by bytedance for scaling depth 2026-02-04 09:09:17 -08:00
lucidrains
a7c4e7f79f best practices 2026-01-28 05:04:20 -08:00
lucidrains
54ec3f2af5 address https://github.com/lucidrains/vit-pytorch/issues/357 2026-01-17 05:11:36 -08:00
lucidrains
9aa52cce49 do an actual vat with siglip arch, and have gemini flash craft the weight loading script from hf 2026-01-15 11:27:05 -08:00
lucidrains
4c89017444 fix up vivit 2026-01-08 06:36:40 -08:00
Eyal Mazuz
580258d99e Allow to pass mask parameter for temporal transformer in ViVit (#356)
* Mask for temporal transformer in ViVit

This allows videos to be padded to a fixed length so that the transformer can
ignore the padded frames when using batch sizes > 1 (see the usage sketch after this commit).

* Added flash attention to vivit

* Added flash attention to vivit

* Added flash attention to vivit
2026-01-08 06:08:54 -08:00
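A hedged usage sketch of the mask feature described in the first bullet above: the import path and class name are assumptions (older versions expose the class as ViT), while the constructor and forward arguments follow the vivit.py diff further down this page.

```python
import torch
from vit_pytorch.vivit import ViViT   # import path assumed

vivit = ViViT(
    dim = 512,
    spatial_depth = 2,
    temporal_depth = 2,
    heads = 4,
    mlp_dim = 1024,
    image_size = 128,
    image_patch_size = 16,
    frames = 8,
    frame_patch_size = 2,
    num_classes = 10,
    variant = 'factorized_encoder'
)

# batch of 2 clips padded to 8 frames; the second clip only has 4 real frames
video = torch.randn(2, 3, 8, 128, 128)          # (batch, channels, frames, height, width)
mask = torch.zeros(2, 8, dtype = torch.bool)    # one boolean per raw frame
mask[0, :] = True                               # clip 0: all 8 frames are real
mask[1, :4] = True                              # clip 1: frames 4..7 are padding

logits = vivit(video, mask = mask)              # padded frames are ignored by temporal attention
assert logits.shape == (2, 10)
```
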
lucidrains
6f1caef987 allow for no final output head on the vit 2026-01-06 13:00:48 -08:00
8 changed files with 884 additions and 45 deletions

View File

@@ -1358,7 +1358,7 @@ learner = Dino(
hidden_layer = 'to_latent', # hidden layer name or index, from which to extract the embedding
projection_hidden_size = 256, # projector network hidden dimension
projection_layers = 4, # number of layers in projection network
num_classes_K = 65336, # output logits dimensions (referenced as K in paper)
num_classes_K = 65536, # output logits dimensions (referenced as K in paper)
student_temp = 0.9, # student temperature
teacher_temp = 0.04, # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper
@@ -2225,4 +2225,28 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{qiu2025gatedattentionlargelanguage,
title = {Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free},
author = {Zihan Qiu and Zekun Wang and Bo Zheng and Zeyu Huang and Kaiyue Wen and Songlin Yang and Rui Men and Le Yu and Fei Huang and Suozhi Huang and Dayiheng Liu and Jingren Zhou and Junyang Lin},
year = {2025},
eprint = {2505.06708},
archivePrefix = {arXiv},
primaryClass = {cs.CL},
url = {https://arxiv.org/abs/2505.06708},
}
```
```bibtex
@misc{chen2026postlayernormbackstableexpressive,
title = {Post-LayerNorm Is Back: Stable, ExpressivE, and Deep},
author = {Chen Chen and Lai Wei},
year = {2026},
eprint = {2601.19895},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2601.19895},
}
```
*I visualise a time when we will be to robots what dogs are to humans, and I'm rooting for the machines.* — Claude Shannon

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.17.1"
version = "1.17.8"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }

View File

@@ -120,6 +120,12 @@ class Attention(Module):
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
self.to_out_gates = nn.Sequential(
nn.Linear(dim, heads),
Rearrange('b ... h -> b h ... 1'),
nn.Sigmoid()
)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
@@ -150,6 +156,9 @@ class Attention(Module):
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = out * self.to_out_gates(x)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)

View File

@@ -92,6 +92,12 @@ class Attention(Module):
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
self.to_out_gates = nn.Sequential(
nn.Linear(dim, heads),
Rearrange('b ... h -> b h ... 1'),
nn.Sigmoid()
)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
@@ -122,6 +128,8 @@ class Attention(Module):
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = out * self.to_out_gates(x) # https://arxiv.org/abs/2505.06708
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
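
For context, a self-contained sketch of the head-wise output gating added in both files above: a sigmoid gate per attention head, computed from the pre-attention token features, following https://arxiv.org/abs/2505.06708. Names here are illustrative, not the repository's exact modules.

```python
import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange

class GatedAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # one sigmoid gate per head, computed from the pre-attention token features
        self.to_out_gates = nn.Sequential(
            nn.Linear(dim, heads),
            Rearrange('b n h -> b h n 1'),
            nn.Sigmoid()
        )
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        q, k, v = map(
            lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads),
            self.to_qkv(x).chunk(3, dim = -1)
        )
        attn = (torch.matmul(q, k.transpose(-1, -2)) * self.scale).softmax(dim = -1)
        out = torch.matmul(attn, v)
        # gate each head's output before merging heads and projecting out
        out = out * self.to_out_gates(x)
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))

x = torch.randn(1, 16, 256)
assert GatedAttention(256, heads = 8, dim_head = 32)(x).shape == (1, 16, 256)
```
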

vit_pytorch/vat_siglip.py (new file, 487 lines)
View File

@@ -0,0 +1,487 @@
from __future__ import annotations
from contextlib import nullcontext
from pathlib import Path
import torch
import torch.nn.functional as F
from torch import nn, cat, stack, tensor, einsum
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# attention
class Attention(Module):
def __init__(
self,
dim,
dim_context = None,
heads = 8,
dim_head = 64,
dropout = 0.,
norm_eps = 1e-6,
gate_attn = False
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.norm = nn.LayerNorm(dim, eps = norm_eps)
self.is_cross_attn = exists(dim_context)
dim_context = default(dim_context, dim)
self.norm_context = nn.LayerNorm(dim_context, eps = norm_eps) if self.is_cross_attn else None
self.to_q = nn.Linear(dim, inner_dim)
self.to_kv = nn.Linear(dim_context, inner_dim * 2)
self.to_out_gates = nn.Sequential(
nn.Linear(dim, heads),
Rearrange('b ... h -> b h ... 1'),
nn.Sigmoid()
) if gate_attn else None
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, context = None):
x = self.norm(x)
if self.is_cross_attn:
assert exists(context)
context = self.norm_context(context)
else:
context = x
q = self.to_q(x)
k, v = self.to_kv(context).chunk(2, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
if exists(self.to_out_gates):
out = out * self.to_out_gates(x) # https://arxiv.org/abs/2505.06708
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
def FeedForward(
dim,
dim_inner,
norm_eps = 1e-6
):
return nn.Sequential(
nn.LayerNorm(dim, eps = norm_eps),
nn.Linear(dim, dim_inner),
nn.GELU(approximate = 'tanh'),
nn.Linear(dim_inner, dim)
)
class SigLIP(Module):
def __init__(
self,
image_size = 224,
patch_size = 14,
dim = 1152,
depth = 27,
heads = 16,
mlp_dim = 4304,
norm_eps = 1e-6
):
super().__init__()
self.dim = dim
self.depth = depth
num_patches = (image_size // patch_size) ** 2
dim_head = dim // heads
self.to_patch_embed = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear(patch_size * patch_size * 3, dim)
)
self.pos_embed = nn.Parameter(torch.randn(num_patches, dim))
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, norm_eps = norm_eps),
FeedForward(dim = dim, dim_inner = mlp_dim, norm_eps = norm_eps)
]))
self.norm = nn.LayerNorm(dim, eps = norm_eps)
def forward(self, x, return_hiddens = False):
x = self.to_patch_embed(x)
num_patches = x.shape[1]
x = x + self.pos_embed[:num_patches]
hiddens = []
for attn, ff in self.layers:
hiddens.append(x)
x = attn(x) + x
x = ff(x) + x
out = self.norm(x)
if not return_hiddens:
return out
return out, stack(hiddens)
class FiLM(Module):
def __init__(self, dim):
super().__init__()
proj = nn.Linear(dim, dim * 2)
self.to_gamma_beta = nn.Sequential(
proj,
Rearrange('b (two d) -> two b 1 d', two = 2)
)
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)
def forward(self, tokens, cond):
gamma, beta = self.to_gamma_beta(cond)
return tokens * gamma + beta
class SigLIPVAT(Module):
def __init__(
self,
*,
dim = 512,
depth = 27,
heads = 8,
dim_head = 64,
dim_action = 32,
mlp_dim = 2048,
num_views = 1,
num_tasks = None,
dim_extra_token = None,
num_register_tokens = 4,
action_chunk_len = 50,
time_seq_len = 1,
dropout = 0.,
add_self_attn = True,
self_attn_heads = 4,
self_attn_dim_head = 32,
vit_layer_indices: tuple[int, ...] | None = None,
siglip_image_size = 224,
siglip_patch_size = 14,
siglip_dim = 1152,
siglip_depth = 27,
siglip_heads = 16,
siglip_mlp_dim = 4304,
siglip_norm_eps = 1e-6,
):
super().__init__()
self.vit = SigLIP(
image_size = siglip_image_size,
patch_size = siglip_patch_size,
dim = siglip_dim,
depth = siglip_depth,
heads = siglip_heads,
mlp_dim = siglip_mlp_dim,
norm_eps = siglip_norm_eps
)
vit_dim = siglip_dim
self.vit_dim = vit_dim
vit_layer_indices = default(vit_layer_indices, tuple(range(depth)))
assert len(vit_layer_indices) == depth, f'number of vit layer indices {len(vit_layer_indices)} does not match the VAT depth {depth}'
self.register_buffer('layer_indices', tensor(vit_layer_indices), persistent = False)
# handle maybe multiple frames
is_video = time_seq_len > 1
self.is_video = is_video
self.time_seq_len = time_seq_len
self.time_pos_emb = nn.Parameter(torch.randn(time_seq_len, vit_dim) * 1e-2) if is_video else None
# maybe view embeddings
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
# handle maybe task conditioning
self.has_tasks = exists(num_tasks)
if self.has_tasks:
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
# register tokens
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
# to action tokens
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
self.layers = ModuleList([])
for _ in range(depth):
maybe_film = FiLM(dim = dim) if self.has_tasks else None
maybe_self_attn = Attention(dim = dim, heads = self_attn_heads, dim_head = self_attn_dim_head, dropout = dropout) if add_self_attn else None
self.layers.append(ModuleList([
maybe_film,
maybe_self_attn,
Attention(dim = dim, dim_context = vit_dim, heads = heads, dim_head = dim_head, dropout = dropout, gate_attn = True),
FeedForward(dim = dim, dim_inner = mlp_dim)
]))
self.final_norm = nn.LayerNorm(dim)
self.to_pred_action = nn.Linear(dim, dim_action, bias = False)
# handle the extra token
self.accept_extra_token = exists(dim_extra_token)
if exists(dim_extra_token):
self.to_extra_token = nn.Linear(dim_extra_token, dim)
def load_siglip(
self,
repo_id = 'google/siglip-so400m-patch14-224',
folder = 'checkpoints/siglip'
):
folder = Path(folder)
if not folder.exists():
from huggingface_hub import snapshot_download
snapshot_download(
repo_id = repo_id,
local_dir = folder,
allow_patterns = ['config.json', 'model.safetensors']
)
from safetensors import safe_open
weights_path = folder / 'model.safetensors'
# Auto-detect prefix based on keys
with safe_open(weights_path, framework = 'pt') as f:
keys = f.keys()
vi_p = ''
if any(k.startswith('paligemma_with_expert.paligemma.model.vision_tower.vision_model') for k in keys):
vi_p = 'paligemma_with_expert.paligemma.model.vision_tower.vision_model.'
elif any(k.startswith('vision_model') for k in keys):
vi_p = 'vision_model.'
pz_state = self.vit.state_dict()
def copy_weight_bias(pz_prefix, vi_prefix):
pz_state[f'{pz_prefix}.weight'].copy_(f.get_tensor(f'{vi_prefix}.weight'))
pz_state[f'{pz_prefix}.bias'].copy_(f.get_tensor(f'{vi_prefix}.bias'))
# patch embedding
patch_weight = rearrange(f.get_tensor(f'{vi_p}embeddings.patch_embedding.weight'), 'd c h w -> d (h w c)')
pz_state['to_patch_embed.1.weight'].copy_(patch_weight)
pz_state['to_patch_embed.1.bias'].copy_(f.get_tensor(f'{vi_p}embeddings.patch_embedding.bias'))
# position embedding
pz_state['pos_embed'].copy_(f.get_tensor(f'{vi_p}embeddings.position_embedding.weight'))
# transformer layers
for i in range(self.vit.depth):
v_pi = f'{vi_p}encoder.layers.{i}'
v_pz = f'layers.{i}'
# attention
copy_weight_bias(f'{v_pz}.0.norm', f'{v_pi}.layer_norm1')
copy_weight_bias(f'{v_pz}.0.to_q', f'{v_pi}.self_attn.q_proj')
vk, vv = [f.get_tensor(f'{v_pi}.self_attn.{x}_proj.weight') for x in ('k', 'v')]
bk, bv = [f.get_tensor(f'{v_pi}.self_attn.{x}_proj.bias') for x in ('k', 'v')]
pz_state[f'{v_pz}.0.to_kv.weight'].copy_(cat((vk, vv), dim = 0))
pz_state[f'{v_pz}.0.to_kv.bias'].copy_(cat((bk, bv), dim = 0))
copy_weight_bias(f'{v_pz}.0.to_out.0', f'{v_pi}.self_attn.out_proj')
# feedforward
copy_weight_bias(f'{v_pz}.1.0', f'{v_pi}.layer_norm2')
copy_weight_bias(f'{v_pz}.1.1', f'{v_pi}.mlp.fc1')
copy_weight_bias(f'{v_pz}.1.3', f'{v_pi}.mlp.fc2')
# post-layernorm
copy_weight_bias('norm', f'{vi_p}post_layernorm')
self.vit.load_state_dict(pz_state)
print(f'Successfully loaded SigLIP weights from {repo_id}')
def forward(
self,
video_or_image, # (b v? c t? h w)
*,
extra = None,
tasks = None,
actions = None,
return_hiddens = False,
freeze_vit = False
):
batch = video_or_image.shape[0]
return_loss = exists(actions)
# handle some various input dimensions
if video_or_image.ndim == 4:
video_or_image = rearrange(video_or_image, 'b c h w -> b 1 c h w')
if video_or_image.ndim == 5:
video_or_image = rearrange(video_or_image, 'b v c h w -> b v c 1 h w')
assert video_or_image.shape[3] == self.time_seq_len
# to images
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
images, packed_shape = pack([images], '* c h w')
# get representation trajectory from vit
vit_forward_context = torch.no_grad if freeze_vit else nullcontext
with vit_forward_context():
embed, hiddens = self.vit(images, return_hiddens = True)
hiddens = cat((hiddens, embed[None, ...]))
# extract the hiddens needed for the action cross attention
hiddens = hiddens[self.layer_indices]
# pack temporarily for embedding
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers
# maybe add time embeddings
if self.is_video:
time_pos_emb = rearrange(self.time_pos_emb, 't d -> t 1 d')
hiddens = hiddens + time_pos_emb
# maybe view embeddings
if exists(self.view_emb):
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
hiddens = hiddens + view_emb
# maybe tasks
if exists(tasks):
task_emb = self.task_emb[tasks]
# cross from actions to representation trajectory
context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
# get main action tokens and maybe append extra
action_tokens = repeat(self.action_pos_emb, 'k d -> b k d', b = batch)
has_extra = exists(extra)
if has_extra:
extra_token = self.to_extra_token(extra)
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
# cross attention
vat_hiddens = [action_tokens]
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
if exists(tasks):
action_tokens = maybe_film(action_tokens, task_emb)
action_tokens = cross_attn(action_tokens, layer_context) + action_tokens
if exists(maybe_self_attn):
action_tokens = maybe_self_attn(action_tokens) + action_tokens
action_tokens = ff(action_tokens) + action_tokens
vat_hiddens.append(action_tokens)
# unpack registers
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
# maybe unpack extra
if has_extra:
action_tokens, _ = unpack(action_tokens, packed_extra, 'b * d')
# norm and prediction
action_tokens = self.final_norm(action_tokens)
pred_action = self.to_pred_action(action_tokens)
if not return_loss:
if not return_hiddens:
return pred_action
return pred_action, stack(vat_hiddens)
assert pred_action.shape[1] == actions.shape[1]
return F.l1_loss(pred_action, actions)
# quick test
if __name__ == '__main__':
vat = SigLIPVAT(
num_tasks = 4,
dim_extra_token = 32,
time_seq_len = 2,
num_views = 2,
depth = 4,
vit_layer_indices = ( # extending on the paper, allow for any order of hiddens, and also allow for depth index (which equates to the final embedding output from the vit)
0, 1, 26, 27
)
)
vat.load_siglip() # load siglip weights from hf
# inputs
images = torch.randn(1, 2, 3, 2, 224, 224) # (b, v, c, t, h, w)
tasks = torch.randint(0, 4, (1,))
extra = torch.randn(1, 32)
actions = torch.randn(1, 50, 32) # actions for learning
loss = vat(images, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
# after much training
pred_actions = vat(images, tasks = tasks, extra = extra)
assert pred_actions.shape == (1, 50, 32)

View File

@@ -113,7 +113,7 @@ class ViT(Module):
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
def forward(self, img):
batch = img.shape[0]
@@ -129,6 +129,9 @@ class ViT(Module):
x = self.transformer(x)
if self.mlp_head is None:
return x
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
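
A hedged usage sketch of the new headless option: constructing the model with num_classes = 0 so the forward pass returns per-token embeddings instead of logits (the top-level import path is an assumption).

```python
import torch
from vit_pytorch import ViT   # import path assumed

encoder = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 0,          # no classification head: forward returns the token sequence
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

tokens = encoder(torch.randn(1, 3, 256, 256))
print(tokens.shape)           # (1, num_tokens, 1024): per-token embeddings rather than logits
```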

View File

@@ -0,0 +1,217 @@
from __future__ import annotations
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# functions
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim, bias = False),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.norm = nn.LayerNorm(dim, bias = False)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(
self,
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout = 0.,
keel_residual_scale = None
):
super().__init__()
assert depth > 1
self.layers = ModuleList([])
for _ in range(depth):
self.layers.extend([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
])
num_layers = depth * 2
self.keel_residual_scale = default(keel_residual_scale, num_layers)
self.post_norms = ModuleList([nn.LayerNorm(dim, bias = False) for _ in range(num_layers - 1)])
def forward(self, x):
residual_scale = self.keel_residual_scale
for layer_ind, layer in enumerate(self.layers):
first_layer = layer_ind == 0
residual = x
out = layer(x)
if first_layer:
x = out + residual
continue
post_norm = self.post_norms[layer_ind - 1]
x = post_norm(out + residual * residual_scale)
return x
class ViT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
pool = 'cls',
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
keel_residual_scale = None
):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_cls_tokens = 1 if pool == 'cls' else 0
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout,
keel_residual_scale = keel_residual_scale
)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
def forward(self, img):
batch = img.shape[0]
x = self.to_patch_embedding(img)
cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
x = torch.cat((cls_tokens, x), dim = 1)
seq = x.shape[1]
x = x + self.pos_embedding[:seq]
x = self.dropout(x)
x = self.transformer(x)
if not exists(self.mlp_head):
return x
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
if __name__ == '__main__':
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img)
assert preds.shape == (1, 1000)

View File

@@ -1,5 +1,10 @@
from collections import namedtuple
import torch
from torch import nn
from torch import nn, cat
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.nn.attention import SDPBackend, sdpa_kernel
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
@@ -9,12 +14,15 @@ from einops.layers.torch import Rearrange
def exists(val):
return val is not None
def divisible_by(num, den):
return (num % den) == 0
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class FeedForward(nn.Module):
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
@@ -28,9 +36,11 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_flash_attn = True):
super().__init__()
self.use_flash_attn = use_flash_attn
self.dropout_p = dropout
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
@@ -48,61 +58,100 @@ class Attention(nn.Module):
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
def flash_attn(self, q, k, v, mask = None):
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
dropout_p = self.dropout_p,
is_causal = False,
scale = self.scale
)
return out
def forward(self, x, mask = None):
batch, seq, _ = x.shape
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
attn = self.attend(dots)
attn = self.dropout(attn)
if self.use_flash_attn:
out = self.flash_attn(q, k, v, mask = mask)
else:
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(mask):
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
super().__init__()
self.use_flash_attn = use_flash_attn
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x) + x
x = attn(x, mask = mask) + x
x = ff(x) + x
return self.norm(x)
class FactorizedTransformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
class FactorizedTransformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
super().__init__()
self.use_flash_attn = use_flash_attn
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
b, f, n, _ = x.shape
def forward(self, x, mask = None):
batch, frames, seq, _ = x.shape
if exists(mask):
mask = repeat(mask, 'b ... -> (b space) ...', space = x.shape[2])
for spatial_attn, temporal_attn, ff in self.layers:
x = rearrange(x, 'b f n d -> (b f) n d')
x = spatial_attn(x) + x
x = rearrange(x, '(b f) n d -> (b n) f d', b=b, f=f)
x = temporal_attn(x) + x
x = rearrange(x, '(b f) n d -> (b n) f d', b = batch, f = frames)
x = temporal_attn(x, mask = mask) + x
x = ff(x) + x
x = rearrange(x, '(b n) f d -> b f n d', b=b, n=n)
x = rearrange(x, '(b n) f d -> b f n d', b = batch, n = seq)
return self.norm(x)
class ViT(nn.Module):
class ViViT(Module):
def __init__(
self,
*,
@@ -122,13 +171,14 @@ class ViT(nn.Module):
dropout = 0.,
emb_dropout = 0.,
variant = 'factorized_encoder',
use_flash_attn: bool = True,
):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(image_patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
assert divisible_by(image_height, patch_height) and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert divisible_by(frames, frame_patch_size), 'Frames must be divisible by frame patch size'
assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
@@ -138,6 +188,8 @@ class ViT(nn.Module):
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.frame_patch_size = frame_patch_size
self.global_average_pool = pool == 'mean'
self.to_patch_embedding = nn.Sequential(
@@ -154,11 +206,11 @@ class ViT(nn.Module):
if variant == 'factorized_encoder':
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
elif variant == 'factorized_self_attention':
assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
self.pool = pool
self.to_latent = nn.Identity()
@@ -166,25 +218,36 @@ class ViT(nn.Module):
self.mlp_head = nn.Linear(dim, num_classes)
self.variant = variant
def forward(self, video):
x = self.to_patch_embedding(video)
b, f, n, _ = x.shape
def forward(self, video, mask = None):
device = video.device
x = x + self.pos_embedding[:, :f, :n]
x = self.to_patch_embedding(video)
batch, frames, seq, _ = x.shape
x = x + self.pos_embedding[:, :frames, :seq]
if exists(self.spatial_cls_token):
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
x = torch.cat((spatial_cls_tokens, x), dim = 2)
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = batch, f = frames)
x = cat((spatial_cls_tokens, x), dim = 2)
x = self.dropout(x)
# maybe temporal mask
temporal_mask = None
if exists(mask):
temporal_mask = reduce(mask, 'b (f patch) -> b f', 'all', patch = self.frame_patch_size)
# the two variants
if self.variant == 'factorized_encoder':
x = rearrange(x, 'b f n d -> (b f) n d')
# attend across space
x = self.spatial_transformer(x)
x = rearrange(x, '(b f) n d -> b f n d', b = b)
x = rearrange(x, '(b f) n d -> b f n d', b = batch)
# excise out the spatial cls tokens or average pool for temporal attention
@@ -193,22 +256,50 @@ class ViT(nn.Module):
# append temporal CLS tokens
if exists(self.temporal_cls_token):
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = batch)
x = torch.cat((temporal_cls_tokens, x), dim = 1)
x = cat((temporal_cls_tokens, x), dim = 1)
if exists(temporal_mask):
temporal_mask = F.pad(temporal_mask, (1, 0), value = True)
# attend across time
x = self.temporal_transformer(x)
x = self.temporal_transformer(x, mask = temporal_mask)
# excise out temporal cls token or average pool
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
elif self.variant == 'factorized_self_attention':
x = self.factorized_transformer(x)
x = self.factorized_transformer(x, mask = temporal_mask)
x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')
x = self.to_latent(x)
return self.mlp_head(x)
# main
if __name__ == '__main__':
vivit = ViViT(
dim = 512,
spatial_depth = 2,
temporal_depth = 2,
heads = 4,
mlp_dim = 2048,
image_size = 256,
image_patch_size = 16,
frames = 8,
frame_patch_size = 2,
num_classes = 1000,
variant = 'factorized_encoder',
)
video = torch.randn(3, 3, 8, 256, 256)
mask = torch.randint(0, 2, (3, 8)).bool()
logits = vivit(video, mask = mask)
assert logits.shape == (3, 1000)