diff --git a/README.md b/README.md index 57babd2..ba974c5 100644 --- a/README.md +++ b/README.md @@ -2319,4 +2319,13 @@ Coming from computer vision and new to transformers? Here are some resources tha } ``` +```bibtex +@misc{kim2026exploring, + title = {Exploring High-Order Self-Similarity for Video Understanding}, + author = {Manjin Kim and Heeseung Kwon and Karteek Alahari and Minsu Cho}, + year = {2026}, + url = {https://openreview.net/forum?id=Co6SCyBIjo} +} +``` + *I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon diff --git a/pyproject.toml b/pyproject.toml index f19a0cd..01a6938 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "vit-pytorch" -version = "1.20.4" +version = "1.20.5" description = "Vision Transformer (ViT) - Pytorch" readme = { file = "README.md", content-type = "text/markdown" } license = { file = "LICENSE" } diff --git a/vit_pytorch/vivit_with_moss.py b/vit_pytorch/vivit_with_moss.py new file mode 100644 index 0000000..709c749 --- /dev/null +++ b/vit_pytorch/vivit_with_moss.py @@ -0,0 +1,386 @@ +# https://openreview.net/forum?id=Co6SCyBIjo +# applied at https://arxiv.org/abs/2605.03269 - 50-85% jump in pick-place moving conveyer belt + +import torch +from torch import nn +from torch.nn import Module, ModuleList +import torch.nn.functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel + +from einops import rearrange, repeat, reduce, einsum +from einops.layers.torch import Rearrange + +# helpers + +def exists(val): + return val is not None + + +def divisible_by(num, den): + return (num % den) == 0 + +def is_odd(n): + return not divisible_by(n, 2) + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +def l2norm(t): + return F.normalize(t, dim = -1) + +# normalization helpers + +class ChanLayerNorm(Module): + def __init__(self, dim, eps = 1e-5): + super().__init__() + self.eps = eps + self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1)) + + def forward(self, x): + var = torch.var(x, dim = 1, unbiased = False, keepdim = True) + mean = torch.mean(x, dim = 1, keepdim = True) + return (x - mean) * (var + self.eps).rsqrt() * self.gamma + +# classes + +class FeedForward(Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_flash_attn = True): + super().__init__() + self.use_flash_attn = use_flash_attn + self.dropout_p = dropout + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.norm = nn.LayerNorm(dim) + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def flash_attn(self, q, k, v, mask = None): + with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]): + out = F.scaled_dot_product_attention( + q, k, v, + attn_mask = mask, + dropout_p = self.dropout_p, + is_causal = False, + scale = self.scale + ) + return out + + def forward(self, x, mask = None): + x = self.norm(x) + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + if exists(mask): + mask = rearrange(mask, 'b j -> b 1 1 j') + + if self.use_flash_attn: + out = self.flash_attn(q, k, v, mask = mask) + else: + dots = einsum(q, k, 'b h i d, b h j d -> b h i j') * self.scale + if exists(mask): + dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max) + + attn = self.attend(dots) + attn = self.dropout(attn) + out = torch.matmul(attn, v) + + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.layers = ModuleList([]) + for _ in range(depth): + self.layers.append(ModuleList([ + Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn), + FeedForward(dim, mlp_dim, dropout = dropout) + ])) + + def forward(self, x, mask = None): + for attn, ff in self.layers: + x = attn(x, mask = mask) + x + x = ff(x) + x + return self.norm(x) + +# moss specific classes + +class STSSEncoder(Module): + def __init__(self, dim, local_time = 3, local_height = 3, local_width = 3, hidden_dim = 64): + super().__init__() + + self.spatial_to_hidden = nn.Linear(local_height * local_width, hidden_dim) + + self.conv = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding = 1), + ChanLayerNorm(hidden_dim), + nn.GELU(), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding = 1), + ChanLayerNorm(hidden_dim), + nn.GELU() + ) + + self.time_to_out = nn.Linear(local_time * hidden_dim, dim) + + def forward(self, sim): + b, t, h, w, lt, lh, lw = sim.shape + + x = rearrange(sim, 'b t h w lt lh lw -> b t h w lt (lh lw)') + x = self.spatial_to_hidden(x) + + x = rearrange(x, 'b t h w lt d -> (b t lt) d h w') + x = self.conv(x) + x = rearrange(x, '(b t lt) d h w -> b t h w (lt d)', b = b, t = t, lt = lt) + + return self.time_to_out(x) + +class MOSS(Module): + def __init__( + self, + dim, + local_time = 3, + local_height = 3, + local_width = 3, + hidden_dim = 64, + orders = 2, + causal = False + ): + super().__init__() + assert is_odd(local_time) and is_odd(local_height) and is_odd(local_width), 'MOSS local dimensions must be odd' + + self.local_time = local_time + self.local_height = local_height + self.local_width = local_width + self.causal = causal + + self.encoders = ModuleList([STSSEncoder(dim, local_time, local_height, local_width, hidden_dim) for _ in range(orders)]) + self.to_order_out = ModuleList([nn.Linear(dim, dim) for _ in range(orders)]) + self.to_out = nn.Linear(dim, dim) + + def stss_transform(self, x): + lt, lh, lw = self.local_time, self.local_height, self.local_width + + x = l2norm(x) + x = rearrange(x, 'b t h w c -> b c t h w') + + pad_h, pad_w = lh // 2, lw // 2 + pad_t_past, pad_t_future = (lt - 1, 0) if self.causal else (lt // 2, lt // 2) + + padded_x = F.pad(x, (pad_w, pad_w, pad_h, pad_h, pad_t_past, pad_t_future)) + windows = padded_x.unfold(2, lt, 1).unfold(3, lh, 1).unfold(4, lw, 1) + + return einsum(x, windows, 'b c t h w, b c t h w l u v -> b t h w l u v') + + def forward(self, x): + out = self.to_out(x) + + for encoder, to_order_out in zip(self.encoders, self.to_order_out): + sim = self.stss_transform(x) + x = encoder(sim) + out = out + to_order_out(x) + + return out + +# main architecture + +class ViViT(Module): + def __init__( + self, + *, + image_size, + image_patch_size, + frames, + frame_patch_size, + num_classes, + dim, + spatial_depth, + temporal_depth, + heads, + mlp_dim, + pool = 'cls', + channels = 3, + dim_head = 64, + dropout = 0., + emb_dropout = 0., + use_flash_attn: bool = True, + moss_local_time = 3, + moss_local_height = 3, + moss_local_width = 3, + moss_hidden_dim = 64, + moss_orders = 2, + moss_causal = True, + ): + super().__init__() + image_height, image_width = pair(image_size) + patch_height, patch_width = pair(image_patch_size) + + assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.' + assert divisible_by(frames, frame_patch_size), 'Frames must be divisible by frame patch size' + assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' + + num_image_patches = (image_height // patch_height) * (image_width // patch_width) + num_frame_patches = frames // frame_patch_size + patch_dim = channels * patch_height * patch_width * frame_patch_size + + self.frame_patch_size = frame_patch_size + self.patch_h = image_height // patch_height + self.patch_w = image_width // patch_width + self.global_average_pool = pool == 'mean' + + self.to_patch_embedding = nn.Sequential( + Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size), + nn.LayerNorm(patch_dim), + nn.Linear(patch_dim, dim), + nn.LayerNorm(dim) + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim)) + self.dropout = nn.Dropout(emb_dropout) + + self.has_cls = not self.global_average_pool + self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if self.has_cls else None + self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if self.has_cls else None + + self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn) + self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn) + + self.moss = MOSS( + dim, + local_time = moss_local_time, + local_height = moss_local_height, + local_width = moss_local_width, + hidden_dim = moss_hidden_dim, + orders = moss_orders, + causal = moss_causal + ) + + self.mlp_head = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, num_classes) + ) + + def forward(self, video, mask = None): + x = self.to_patch_embedding(video) + batch, frames, seq, _ = x.shape + + x = x + self.pos_embedding[:, :frames, :seq] + + if self.has_cls: + spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = batch, f = frames) + x = torch.cat((spatial_cls_tokens, x), dim = 2) + + x = self.dropout(x) + + # temporal mask + + temporal_mask = None + if exists(mask): + temporal_mask = reduce(mask, 'b (f patch) -> b f', 'all', patch = self.frame_patch_size) + + x = rearrange(x, 'b f n d -> (b f) n d') + + # attend across space + x = self.spatial_transformer(x) + x = rearrange(x, '(b f) n d -> b f n d', b = batch) + + # moss integration over spatial patch tokens + + if self.has_cls: + spatial_cls_tokens, patch_tokens = x[:, :, :1], x[:, :, 1:] + else: + patch_tokens = x + + patch_tokens = rearrange(patch_tokens, 'b f (h w) d -> b f h w d', h = self.patch_h, w = self.patch_w) + patch_tokens = self.moss(patch_tokens) + patch_tokens = rearrange(patch_tokens, 'b f h w d -> b f (h w) d') + + # pool spatial features + + moss_pooled = reduce(patch_tokens, 'b f n d -> b f d', 'mean') + + if self.has_cls: + x = rearrange(spatial_cls_tokens, 'b f 1 d -> b f d') + moss_pooled + else: + x = moss_pooled + + # append temporal cls tokens + + if self.has_cls: + temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d -> b 1 d', b = batch) + x = torch.cat((temporal_cls_tokens, x), dim = 1) + + if exists(temporal_mask): + temporal_mask = F.pad(temporal_mask, (1, 0), value = True) + + # attend across time + + x = self.temporal_transformer(x, mask = temporal_mask) + + # temporal pooling + + x = x[:, 0] if self.has_cls else reduce(x, 'b f d -> b d', 'mean') + + return self.mlp_head(x) + +if __name__ == '__main__': + vivit = ViViT( + dim = 512, + spatial_depth = 2, + temporal_depth = 2, + heads = 4, + mlp_dim = 2048, + image_size = 256, + image_patch_size = 32, + frames = 8, + frame_patch_size = 2, + num_classes = 1000, + ) + + video = torch.randn(2, 3, 8, 256, 256) + mask = torch.randint(0, 2, (2, 8)).bool() + + logits = vivit(video, mask = None) + assert logits.shape == (2, 1000) + + logits = vivit(video, mask = mask) + assert logits.shape == (2, 1000) + + moss = MOSS( + dim = 512, + local_time = 3, + local_height = 3, + local_width = 3, + hidden_dim = 64, + orders = 2, + causal = True + ) + + moss_input = torch.randn(2, 8, 16, 16, 512) # (batch, frames, height, width, dim) + moss_output = moss(moss_input) + assert moss_output.shape == (2, 8, 16, 16, 512)