From f7d59cecb5a368ff910ce93bf0c70daff8378ca7 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 24 Oct 2025 14:00:38 -0700
Subject: [PATCH] some register tokens cannot hurt for VAT

---
 pyproject.toml     |  2 +-
 vit_pytorch/vat.py | 38 +++++++++++++++++++++++++++++++-------
 2 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 7ae8399..f0fe609 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.14.4"
+version = "1.14.5"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
diff --git a/vit_pytorch/vat.py b/vit_pytorch/vat.py
index aa8f529..14f6865 100644
--- a/vit_pytorch/vat.py
+++ b/vit_pytorch/vat.py
@@ -178,7 +178,8 @@ class ViT(Module):
         channels = 3,
         dim_head = 64,
         dropout = 0.,
-        emb_dropout = 0.
+        emb_dropout = 0.,
+        num_register_tokens = 0
     ):
         super().__init__()
         self.dim = dim
@@ -200,8 +201,8 @@ class ViT(Module):
             nn.LayerNorm(dim),
         )
 
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
-        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
+        self.cls_token = nn.Parameter(torch.randn(dim))
         self.dropout = nn.Dropout(emb_dropout)
 
         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
@@ -211,13 +212,19 @@ class ViT(Module):
 
         self.mlp_head = nn.Linear(dim, num_classes)
 
+        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
+
     def forward(self, img, return_hiddens = False):
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape
 
-        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
-        x = cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding[:, :(n + 1)]
+        x += self.pos_embedding[:n]
+
+        cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
+        register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = b)
+
+        x, packed_shape = pack((register_tokens, cls_tokens, x), 'b * d')
+
         x = self.dropout(x)
 
         x, hiddens = self.transformer(x, return_hiddens = True)
@@ -227,7 +234,9 @@ class ViT(Module):
         if return_hiddens:
             return x, stack(hiddens)
 
-        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
+        register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
+
+        x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
 
         x = self.to_latent(x)
         return self.mlp_head(x)
@@ -251,6 +260,7 @@ class VAT(Module):
         num_views = None,
         num_tasks = None,
         dim_extra_token = None,
+        num_register_tokens = 4,
         action_chunk_len = 7,
         time_seq_len = 1,
         dropout = 0.,
@@ -295,6 +305,10 @@ class VAT(Module):
         if self.has_tasks:
             self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
 
+        # register tokens from Darcet et al.
+
+        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
+
         # to action tokens
 
         self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
@@ -407,6 +421,12 @@ class VAT(Module):
 
             action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
 
+        # register tokens
+
+        register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
+
+        action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
+
         # cross attention
 
         hiddens = [action_tokens]
@@ -425,6 +445,10 @@ class VAT(Module):
 
            hiddens.append(action_tokens)
 
+        # unpack registers
+
+        _, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
+
         # maybe unpack extra
 
         if has_extra:
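
For context, the change above prepends a small set of learned register tokens (Darcet et al., "Vision Transformers Need Registers") to the token sequence before attention and strips them out again afterwards, using einops pack/unpack for the bookkeeping. Below is a minimal standalone sketch of that pattern, assuming only torch and einops; the toy dimensions and variable names are illustrative and not taken from vat.py.

import torch
from torch import nn
from einops import repeat, pack, unpack

dim, num_register_tokens = 64, 4

# learned registers, initialized at a small scale as in the patch (* 1e-2)
register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)

batch, seq = 2, 16
tokens = torch.randn(batch, seq, dim)          # stand-in for patch / action tokens

# broadcast the registers across the batch and prepend them to the sequence
registers = repeat(register_tokens, 'n d -> b n d', b = batch)
tokens, packed_shape = pack((registers, tokens), 'b * d')    # (batch, 4 + 16, dim)

# ... attention layers would run over `tokens` here ...

# discard the registers afterwards; the original token count is unchanged
_, tokens = unpack(tokens, packed_shape, 'b * d')            # back to (batch, 16, dim)

Note that in the patch the positional embedding is added before the registers (and cls token) are packed in, so the registers carry no position and are dropped before pooling or decoding; they only serve as extra scratch slots the attention layers can write global information into.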