mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-05-13 11:41:49 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6f1caef987 |
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vit-pytorch"
|
||||
version = "1.17.1"
|
||||
version = "1.17.3"
|
||||
description = "Vision Transformer (ViT) - Pytorch"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
license = { file = "LICENSE" }
|
||||
|
||||
@@ -113,7 +113,7 @@ class ViT(Module):
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
|
||||
|
||||
def forward(self, img):
|
||||
batch = img.shape[0]
|
||||
@@ -129,6 +129,9 @@ class ViT(Module):
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
if self.mlp_head is None:
|
||||
return x
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
|
||||
Reference in New Issue
Block a user