fix positional embed for mean pool case and cleanup

This commit is contained in:
lucidrains
2025-11-27 17:01:47 -08:00
parent 0ebd4edab9
commit fdaf7f92b9
2 changed files with 18 additions and 13 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.16.0"
version = "1.16.1"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }

View File

@@ -1,5 +1,6 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
@@ -11,7 +12,7 @@ def pair(t):
# classes
class FeedForward(nn.Module):
class FeedForward(Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
@@ -26,7 +27,7 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
@@ -62,13 +63,14 @@ class Attention(nn.Module):
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
@@ -80,7 +82,7 @@ class Transformer(nn.Module):
return self.norm(x)
class ViT(nn.Module):
class ViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
@@ -101,8 +103,8 @@ class ViT(nn.Module):
nn.LayerNorm(dim),
)
self.cls_token = nn.Parameter(torch.randn(1, num_cls_tokens, dim))
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + num_cls_tokens, dim))
self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
self.dropout = nn.Dropout(emb_dropout)
@@ -114,12 +116,15 @@ class ViT(nn.Module):
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, img):
batch = img.shape[0]
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 ... d -> b ... d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
x = torch.cat((cls_tokens, x), dim = 1)
seq = x.shape[1]
x = x + self.pos_embedding[:seq]
x = self.dropout(x)
x = self.transformer(x)