mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-29 23:52:27 +00:00
fix linear head in simple vit, thanks to @atkos
This commit is contained in:
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.4.0',
|
||||
version = '1.4.1',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description_content_type = 'text/markdown',
|
||||
|
||||
@@ -105,7 +105,7 @@ class SimpleViT(nn.Module):
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.LayerNorm(dim)
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
device = img.device
|
||||
|
||||
Reference in New Issue
Block a user