add register tokens to the nested tensor 3d na vit example for a researcher

lucidrains
2024-08-28 12:21:31 -07:00
parent c4651a35a3
commit fcb9501cdd
2 changed files with 23 additions and 7 deletions
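The change prepends a small set of learned register tokens to every patch sequence before the transformer, presumably in the spirit of "Vision Transformers Need Registers" (Darcet et al.). A minimal standalone sketch of the idea, with illustrative names and shapes rather than the library's API:

import torch
from torch import nn

dim, num_registers = 64, 4

# a learned (num_registers, dim) parameter, shared across all samples
register_tokens = nn.Parameter(torch.zeros(num_registers, dim))
nn.init.normal_(register_tokens, std = 0.02)

patch_tokens = torch.randn(197, dim)                   # one variable-length patch sequence
tokens = torch.cat((register_tokens, patch_tokens))    # (num_registers + 197, dim)

The hunks below wire this into the nested tensor 3D NaViT example.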


@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.7.11',
+  version = '1.7.12',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,


@@ -163,6 +163,7 @@ class NaViT(Module):
         dim_head = 64,
         dropout = 0.,
         emb_dropout = 0.,
+        num_registers = 4,
         token_dropout_prob: float | None = None
     ):
         super().__init__()
@@ -193,9 +194,18 @@ class NaViT(Module):
             nn.LayerNorm(dim),
         )
 
-        self.pos_embed_frame = nn.Parameter(torch.randn(patch_frame_dim, dim))
-        self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
-        self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))
+        self.pos_embed_frame = nn.Parameter(torch.zeros(patch_frame_dim, dim))
+        self.pos_embed_height = nn.Parameter(torch.zeros(patch_height_dim, dim))
+        self.pos_embed_width = nn.Parameter(torch.zeros(patch_width_dim, dim))
+
+        # register tokens
+
+        self.register_tokens = nn.Parameter(torch.zeros(num_registers, dim))
+
+        nn.init.normal_(self.pos_embed_frame, std = 0.02)
+        nn.init.normal_(self.pos_embed_height, std = 0.02)
+        nn.init.normal_(self.pos_embed_width, std = 0.02)
+        nn.init.normal_(self.register_tokens, std = 0.02)
 
         self.dropout = nn.Dropout(emb_dropout)
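For context on the initialization change above: torch.randn fills the positional embeddings with unit-variance noise, while the new code allocates zeros and then fills them in place with nn.init.normal_ at std 0.02, the small scale conventionally used for ViT-style embeddings; the new register tokens get the same treatment. A rough before/after sketch with made-up shapes:

import torch
from torch import nn

patch_height_dim, dim = 32, 64                                # hypothetical sizes

pos_old = nn.Parameter(torch.randn(patch_height_dim, dim))    # before: N(0, 1)

pos_new = nn.Parameter(torch.zeros(patch_height_dim, dim))    # after: zeros ...
nn.init.normal_(pos_new, std = 0.02)                          # ... filled in place with N(0, 0.02^2)

print(round(pos_old.std().item(), 2), round(pos_new.std().item(), 2))   # roughly 1.0 vs 0.02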
@@ -275,8 +285,6 @@ class NaViT(Module):
         pos_embed = frame_embed + height_embed + width_embed
 
-        # use nested tensor for transformers and save on padding computation
-
         tokens = torch.cat(tokens)
 
         # linear projection to patch embeddings
@@ -287,7 +295,15 @@ class NaViT(Module):
         tokens = tokens + pos_embed
 
-        tokens = nested_tensor(tokens.split(seq_lens.tolist()), layout = torch.jagged, device = device)
+        # add register tokens
+
+        tokens = tokens.split(seq_lens.tolist())
+        tokens = [torch.cat((self.register_tokens, one_tokens)) for one_tokens in tokens]
+
+        # use nested tensor for transformers and save on padding computation
+
+        tokens = nested_tensor(tokens, layout = torch.jagged, device = device)
 
         # embedding dropout
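Putting the forward-path change together: the flat token tensor is split back into per-sample sequences, the shared register tokens are prepended to each one, and the ragged batch is packed into a jagged nested tensor so the transformer never sees padding. A self-contained toy version with illustrative sizes, not the library's API:

import torch
from torch.nested import nested_tensor

dim, num_registers = 64, 4

register_tokens = torch.randn(num_registers, dim) * 0.02     # stands in for the learned parameter

seq_lens = torch.tensor([9, 17, 5])                          # ragged sequence lengths per sample
tokens = torch.randn(int(seq_lens.sum()), dim)               # all patch tokens, concatenated

tokens = tokens.split(seq_lens.tolist())                     # tuple of (seq_len, dim) tensors
tokens = [torch.cat((register_tokens, t)) for t in tokens]   # prepend the shared registers to each

tokens = nested_tensor(tokens, layout = torch.jagged)        # ragged batch, no padding needed
print([t.shape[0] for t in tokens.unbind()])                 # [13, 21, 9]

Each sequence is now num_registers tokens longer; how the registers are treated after the transformer (e.g. dropped before pooling) is not part of the hunks above.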