From ad80b6c51ef7d1ff192189f71ccd5bd243c349b7 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 27 Nov 2025 16:56:36 -0800
Subject: [PATCH] fix positional embed for mean pool case and cleanup

---
 pyproject.toml     |  2 +-
 vit_pytorch/vit.py | 30 ++++++++++++++++++------------
 2 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index afc281d..eff48eb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.16.0"
+version = "1.16.1"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
diff --git a/vit_pytorch/vit.py b/vit_pytorch/vit.py
index b0ff7a3..a7eacc9 100644
--- a/vit_pytorch/vit.py
+++ b/vit_pytorch/vit.py
@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+from torch.nn import Module, ModuleList
 
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
@@ -11,7 +12,7 @@ def pair(t):
 
 # classes
 
-class FeedForward(nn.Module):
+class FeedForward(Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
@@ -26,7 +27,7 @@ class FeedForward(nn.Module):
     def forward(self, x):
         return self.net(x)
 
-class Attention(nn.Module):
+class Attention(Module):
     def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         super().__init__()
         inner_dim = dim_head * heads
@@ -62,13 +63,14 @@ class Attention(nn.Module):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
-class Transformer(nn.Module):
+class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])
+
        for _ in range(depth):
-            self.layers.append(nn.ModuleList([
+            self.layers.append(ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
@@ -80,7 +82,7 @@ class Transformer(nn.Module):
 
         return self.norm(x)
 
-class ViT(nn.Module):
+class ViT(Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
         super().__init__()
         image_height, image_width = pair(image_size)
@@ -101,8 +103,9 @@ class ViT(nn.Module):
             nn.LayerNorm(dim),
         )
 
-        self.cls_token = nn.Parameter(torch.randn(1, num_cls_tokens, dim))
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + num_cls_tokens, dim))
+        self.num_cls_tokens = num_cls_tokens
+        self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
 
         self.dropout = nn.Dropout(emb_dropout)
 
@@ -114,12 +117,15 @@ class ViT(nn.Module):
         self.mlp_head = nn.Linear(dim, num_classes)
 
     def forward(self, img):
+        batch = img.shape[0]
         x = self.to_patch_embedding(img)
-        b, n, _ = x.shape
 
-        cls_tokens = repeat(self.cls_token, '1 ... d -> b ... d', b = b)
-        x = torch.cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding[:, :(n + 1)]
+        cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
+        x = torch.cat((cls_tokens, x), dim = 1)
+
+        seq = x.shape[1]
+
+        x = x + self.pos_embedding[:seq]
         x = self.dropout(x)
 
         x = self.transformer(x)
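
Not part of the patch: a minimal usage sketch exercising the mean-pool configuration this commit targets. Only the constructor signature and the pool = 'mean' keyword come from vit_pytorch/vit.py above; the hyperparameter values and tensor sizes below are arbitrary illustrations.

# Illustrative sketch only, not included in the patch.
import torch
from vit_pytorch import ViT

# example hyperparameters (assumed values, not taken from the diff)
model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    pool = 'mean',      # the pooling mode whose positional embedding this fix addresses
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

# after the fix, forward adds pos_embedding[:seq], i.e. one positional
# embedding per token in the sequence, whether or not cls tokens are present
preds = model(img)
assert preds.shape == (1, 1000)  # logits over num_classes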