fix pooling bugs across a few new archs

Phil Wang
2021-04-19 22:36:23 -07:00
parent 4c29328363
commit 34f78294d3
4 changed files with 9 additions and 6 deletions


@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.16.4',
+  version = '0.16.6',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',


@@ -149,4 +149,4 @@ class LocalViT(nn.Module):
         x = self.transformer(x)
-        return self.mlp_head(x)
+        return self.mlp_head(x[:, 0])


@@ -162,8 +162,9 @@ class PiT(nn.Module):
                 layers.append(Pool(dim))
                 dim *= 2
 
-        self.layers = nn.Sequential(
-            *layers,
+        self.layers = nn.Sequential(*layers)
+
+        self.mlp_head = nn.Sequential(
             nn.LayerNorm(dim),
             nn.Linear(dim, num_classes)
         )
@@ -177,4 +178,6 @@ class PiT(nn.Module):
         x += self.pos_embedding
         x = self.dropout(x)
 
-        return self.layers(x)
+        x = self.layers(x)
+
+        return self.mlp_head(x[:, 0])


@@ -192,4 +192,4 @@ class RvT(nn.Module):
         fmap_dims = {'h': h // p, 'w': w // p}
         x = self.transformer(x, fmap_dims = fmap_dims)
 
-        return self.mlp_head(x)
+        return self.mlp_head(x[:, 0])
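All three fixes apply the same pattern: the MLP head is run on the class token (the first token along the sequence dimension) via x[:, 0], rather than on the full token sequence. Below is a minimal sketch of that pooling step in isolation; the batch size, token count, dim, and num_classes values are illustrative only, not defaults from the repo.

    import torch
    from torch import nn

    dim, num_classes = 256, 1000   # illustrative sizes, not repo defaults

    # same head structure as in the PiT hunk above: LayerNorm followed by a Linear classifier
    mlp_head = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, num_classes)
    )

    # (batch, tokens, dim): one cls token followed by 64 patch tokens
    x = torch.randn(2, 1 + 64, dim)

    # the fix: pool by indexing the cls token, instead of feeding every token to the head
    logits = mlp_head(x[:, 0])     # shape (2, num_classes)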