From 6aa037431347c8ee819b5598098f91ad9c0956eb Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 22 Nov 2025 08:12:01 -0800 Subject: [PATCH] register tokens for the AST in VAAT --- pyproject.toml | 2 +- vit_pytorch/vaat.py | 16 ++++++++++++++-- vit_pytorch/vat.py | 2 +- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index db9adf7..211ddd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "vit-pytorch" -version = "1.15.5" +version = "1.15.6" description = "Vision Transformer (ViT) - Pytorch" readme = { file = "README.md", content-type = "text/markdown" } license = { file = "LICENSE" } diff --git a/vit_pytorch/vaat.py b/vit_pytorch/vaat.py index 46cdf8f..c62c409 100644 --- a/vit_pytorch/vaat.py +++ b/vit_pytorch/vaat.py @@ -215,7 +215,8 @@ class AST(Module): spec_hop_length = None, spec_pad = 0, spec_center = True, - spec_pad_mode = 'reflect' + spec_pad_mode = 'reflect', + num_register_tokens = 4 ): super().__init__() self.dim = dim @@ -256,8 +257,11 @@ class AST(Module): ) self.final_norm = nn.LayerNorm(dim) + self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity() + self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2) + def forward( self, raw_audio_or_spec, # (b t) | (b f t) @@ -296,6 +300,12 @@ class AST(Module): tokens = rearrange(tokens, 'b ... c -> b (...) c') + # register tokens + + register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch) + + tokens, packed_shape = pack((register_tokens, tokens), 'b * d') + # attention attended, hiddens = self.transformer(tokens, return_hiddens = True) @@ -307,6 +317,8 @@ class AST(Module): if return_hiddens: return normed, stack(hiddens) + register_tokens, normed = unpack(normed, packed_shape, 'b * d') + pooled = reduce(normed, 'b n d -> b d', 'mean') maybe_logits = self.mlp_head(pooled) @@ -384,7 +396,7 @@ class ViT(Module): if return_hiddens: return x, stack(hiddens) - cls_tokens, x, register_tokens = unpack(x, packed_shape, 'b * d') + register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d') x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens diff --git a/vit_pytorch/vat.py b/vit_pytorch/vat.py index 2ee93cb..1397879 100644 --- a/vit_pytorch/vat.py +++ b/vit_pytorch/vat.py @@ -237,7 +237,7 @@ class ViT(Module): if return_hiddens: return x, stack(hiddens) - cls_tokens, x, register_tokens = unpack(x, packed_shape, 'b * d') + register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d') x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens