some register tokens cannot hurt for VAT

This commit is contained in:
lucidrains
2025-10-24 14:00:38 -07:00
parent a583cb5988
commit f7d59cecb5
2 changed files with 32 additions and 8 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.14.4"
version = "1.14.5"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }

View File

@@ -178,7 +178,8 @@ class ViT(Module):
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.
emb_dropout = 0.,
num_register_tokens = 0
):
super().__init__()
self.dim = dim
@@ -200,8 +201,8 @@ class ViT(Module):
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
self.cls_token = nn.Parameter(torch.randn(dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
@@ -211,13 +212,19 @@ class ViT(Module):
self.mlp_head = nn.Linear(dim, num_classes)
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
def forward(self, img, return_hiddens = False):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x += self.pos_embedding[:n]
cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = b)
x, packed_shape = pack((register_tokens, cls_tokens, x), 'b * d')
x = self.dropout(x)
x, hiddens = self.transformer(x, return_hiddens = True)
@@ -227,7 +234,9 @@ class ViT(Module):
if return_hiddens:
return x, stack(hiddens)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
cls_tokens, x, register_tokens = unpack(x, packed_shape, 'b * d')
x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
x = self.to_latent(x)
return self.mlp_head(x)
@@ -251,6 +260,7 @@ class VAT(Module):
num_views = None,
num_tasks = None,
dim_extra_token = None,
num_register_tokens = 4,
action_chunk_len = 7,
time_seq_len = 1,
dropout = 0.,
@@ -295,6 +305,10 @@ class VAT(Module):
if self.has_tasks:
self.task_emb = nn.Parameter(torch.randn(num_tasks, dim) * 1e-2)
# register tokens from Darcet et al.
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
# to action tokens
self.action_pos_emb = nn.Parameter(torch.randn(action_chunk_len, dim) * 1e-2)
@@ -407,6 +421,12 @@ class VAT(Module):
action_tokens, packed_extra = pack([action_tokens, extra_token], 'b * d')
# register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
action_tokens, registers_packed_shape = pack((register_tokens, action_tokens), 'b * d')
# cross attention
hiddens = [action_tokens]
@@ -425,6 +445,10 @@ class VAT(Module):
hiddens.append(action_tokens)
# unpack registers
_, action_tokens = unpack(action_tokens, registers_packed_shape, 'b * d')
# maybe unpack extra
if has_extra: