From 950c901b80e08be700425b396718852a31f28097 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 10 Aug 2023 14:36:21 -0700 Subject: [PATCH] fix linear head in simple vit, thanks to @atkos --- setup.py | 2 +- vit_pytorch/simple_vit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 4d9e283..02209ec 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.4.0', + version = '1.4.1', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description_content_type = 'text/markdown', diff --git a/vit_pytorch/simple_vit.py b/vit_pytorch/simple_vit.py index f535693..54b5f9c 100644 --- a/vit_pytorch/simple_vit.py +++ b/vit_pytorch/simple_vit.py @@ -105,7 +105,7 @@ class SimpleViT(nn.Module): self.pool = "mean" self.to_latent = nn.Identity() - self.linear_head = nn.LayerNorm(dim) + self.linear_head = nn.Linear(dim, num_classes) def forward(self, img): device = img.device