From 0ebd4edab94c211eff78406248896588d155a10c Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 27 Nov 2025 06:07:43 -0800
Subject: [PATCH] address https://github.com/lucidrains/vit-pytorch/issues/351

---
 pyproject.toml      | 2 +-
 vit_pytorch/vaat.py | 6 +++---
 vit_pytorch/vit.py  | 9 ++++++---
 3 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 623497e..afc281d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.15.7"
+version = "1.16.0"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
diff --git a/vit_pytorch/vaat.py b/vit_pytorch/vaat.py
index 58fccfc..da17507 100644
--- a/vit_pytorch/vaat.py
+++ b/vit_pytorch/vaat.py
@@ -735,7 +735,7 @@ if __name__ == '__main__':
         mlp_dim = 384 * 4
     )
 
-    vat = VAAT(
+    vaat = VAAT(
         vit,
         ast,
         dim = 512,
@@ -767,11 +767,11 @@
     actions = torch.randn(2, 7, 20)
 
     # actions for learning
-    loss = vat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
+    loss = vaat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
     loss.backward()
 
     # after much training
 
-    pred_actions, hiddens = vat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
+    pred_actions, hiddens = vaat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
 
     assert pred_actions.shape == (2, 7, 20)
diff --git a/vit_pytorch/vit.py b/vit_pytorch/vit.py
index 5b34a44..b0ff7a3 100644
--- a/vit_pytorch/vit.py
+++ b/vit_pytorch/vit.py
@@ -90,7 +90,9 @@ class ViT(nn.Module):
 
         num_patches = (image_height // patch_height) * (image_width // patch_width)
         patch_dim = channels * patch_height * patch_width
+
         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
+        num_cls_tokens = 1 if pool == 'cls' else 0
 
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
@@ -99,8 +101,9 @@
             nn.LayerNorm(dim),
         )
 
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
-        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.cls_token = nn.Parameter(torch.randn(1, num_cls_tokens, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + num_cls_tokens, dim))
+
         self.dropout = nn.Dropout(emb_dropout)
 
         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
@@ -114,7 +117,7 @@
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape
 
-        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
+        cls_tokens = repeat(self.cls_token, '1 ... d -> b ... d', b = b)
         x = torch.cat((cls_tokens, x), dim=1)
         x += self.pos_embedding[:, :(n + 1)]
         x = self.dropout(x)
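
Note on the mechanics of the vit.py change: with pool = 'mean', num_cls_tokens is 0, so cls_token becomes a zero-width parameter and the concatenation in forward degenerates to a no-op. Below is a minimal sketch of why the touched lines stay correct for both pool modes; dim = 512 and the shapes are illustrative, not taken from the patch.

    import torch
    from einops import repeat

    dim = 512

    # pool = 'mean' -> num_cls_tokens = 0, so cls_token is a zero-width parameter
    cls_token = torch.randn(1, 0, dim)

    # the old '1 1 d -> b 1 d' pattern would reject a 0-token axis; the new
    # '1 ... d -> b ... d' pattern accepts whatever width the token axis has
    cls_tokens = repeat(cls_token, '1 ... d -> b ... d', b = 4)
    assert cls_tokens.shape == (4, 0, dim)

    # concatenating a zero-width tensor along the token axis is a no-op
    x = torch.randn(4, 64, dim)
    x = torch.cat((cls_tokens, x), dim = 1)
    assert x.shape == (4, 64, dim)

    # the untouched slice pos_embedding[:, :(n + 1)] also stays safe, since
    # slicing past the end of a dimension is clamped rather than an error
    pos_embedding = torch.randn(1, 64, dim)  # num_patches + 0 cls tokens
    assert pos_embedding[:, :(64 + 1)].shape == (1, 64, dim)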
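
And an end-to-end sanity check of the new behavior, assuming the patched package is installed; the constructor arguments mirror the repository README, with the exact values chosen for illustration.

    import torch
    from vit_pytorch import ViT

    v = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        pool = 'mean'  # previously still allocated a cls token; now it does not
    )

    # no cls token and no extra positional embedding slot under mean pooling
    assert v.cls_token.shape == (1, 0, 1024)
    assert v.pos_embedding.shape == (1, (256 // 32) ** 2, 1024)

    img = torch.randn(2, 3, 256, 256)
    preds = v(img)
    assert preds.shape == (2, 1000)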