register tokens for the AST in VAAT

lucidrains
2025-11-22 08:12:01 -08:00
parent b35a97de05
commit 6aa0374313
3 changed files with 16 additions and 4 deletions


@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "vit-pytorch"
-version = "1.15.5"
+version = "1.15.6"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }


@@ -215,7 +215,8 @@ class AST(Module):
         spec_hop_length = None,
         spec_pad = 0,
         spec_center = True,
-        spec_pad_mode = 'reflect'
+        spec_pad_mode = 'reflect',
+        num_register_tokens = 4
     ):
         super().__init__()
         self.dim = dim
@@ -256,8 +257,11 @@ class AST(Module):
         )
         self.final_norm = nn.LayerNorm(dim)
         self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
+        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)

     def forward(
         self,
         raw_audio_or_spec, # (b t) | (b f t)
@@ -296,6 +300,12 @@ class AST(Module):
         tokens = rearrange(tokens, 'b ... c -> b (...) c')
+        # register tokens
+        register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
+        tokens, packed_shape = pack((register_tokens, tokens), 'b * d')
         # attention
         attended, hiddens = self.transformer(tokens, return_hiddens = True)
@@ -307,6 +317,8 @@ class AST(Module):
         if return_hiddens:
             return normed, stack(hiddens)
+        register_tokens, normed = unpack(normed, packed_shape, 'b * d')
         pooled = reduce(normed, 'b n d -> b d', 'mean')
         maybe_logits = self.mlp_head(pooled)
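
The hunks above prepend learned register tokens to the patch tokens before attention and strip them back out before mean pooling. Below is a minimal, self-contained sketch of that flow; TinyRegisterTokenModel and its LayerNorm are illustrative stand-ins rather than the actual AST class, and only the repeat/pack/unpack/reduce pattern mirrors the diff.

# Illustrative sketch (not part of the commit): the register-token flow added to AST.forward.
import torch
from torch import nn
from einops import repeat, reduce, pack, unpack

class TinyRegisterTokenModel(nn.Module):
    def __init__(self, dim = 64, num_register_tokens = 4):
        super().__init__()
        # learned register tokens, initialized small, as in the diff
        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
        self.norm = nn.LayerNorm(dim)  # stand-in for the transformer + final norm

    def forward(self, tokens):  # tokens: (batch, seq, dim)
        batch = tokens.shape[0]
        # expand the register tokens across the batch and pack them in front of the sequence
        register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
        tokens, packed_shape = pack((register_tokens, tokens), 'b * d')
        tokens = self.norm(tokens)
        # split the register tokens back out, in the same order they were packed
        register_tokens, tokens = unpack(tokens, packed_shape, 'b * d')
        # mean-pool only the content tokens, as the diff does before mlp_head
        return reduce(tokens, 'b n d -> b d', 'mean')

model = TinyRegisterTokenModel()
pooled = model(torch.randn(2, 16, 64))
assert pooled.shape == (2, 64)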
@@ -384,7 +396,7 @@ class ViT(Module):
         if return_hiddens:
             return x, stack(hiddens)
-        cls_tokens, x, register_tokens = unpack(x, packed_shape, 'b * d')
+        register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
         x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens


@@ -237,7 +237,7 @@ class ViT(Module):
         if return_hiddens:
             return x, stack(hiddens)
-        cls_tokens, x, register_tokens = unpack(x, packed_shape, 'b * d')
+        register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
         x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
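
The two ViT hunks fix the order of names coming out of unpack: einops unpack returns segments in the order the tensors were packed, and since the register tokens are packed ahead of the CLS token and patch tokens, they have to be read back first. A small illustration, with shapes and the pack order assumed from the fixed unpack line:

# Illustration only: unpack must name segments in the same order they were packed.
import torch
from einops import pack, unpack

register_tokens = torch.randn(2, 4, 8)   # (batch, num_register_tokens, dim)
cls_tokens      = torch.randn(2, 1, 8)   # (batch, 1, dim)
patches         = torch.randn(2, 16, 8)  # (batch, num_patches, dim)

# assumed pack order in ViT.forward, inferred from the diff: registers, cls, patches
x, packed_shape = pack((register_tokens, cls_tokens, patches), 'b * d')

# correct: names follow the pack order, so cls_tokens really is the (b, 1, d) segment
register_tokens, cls_tokens, patches = unpack(x, packed_shape, 'b * d')
assert cls_tokens.shape == (2, 1, 8) and register_tokens.shape == (2, 4, 8)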