fix linear head in simple vit, thanks to @atkos

This commit is contained in:
lucidrains
2023-08-10 14:36:21 -07:00
parent 3e5d1be6f0
commit 950c901b80
2 changed files with 2 additions and 2 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.4.0',
version = '1.4.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',

View File

@@ -105,7 +105,7 @@ class SimpleViT(nn.Module):
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.LayerNorm(dim)
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
device = img.device