Compare commits

...

1 Commits

Author SHA1 Message Date
lucidrains
6f1caef987 allow for no final output head on the vit 2026-01-06 13:00:48 -08:00
2 changed files with 5 additions and 2 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.17.1"
version = "1.17.3"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }

View File

@@ -113,7 +113,7 @@ class ViT(Module):
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
def forward(self, img):
batch = img.shape[0]
@@ -129,6 +129,9 @@ class ViT(Module):
x = self.transformer(x)
if self.mlp_head is None:
return x
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)