From 59787a6b7e89bbed8a9ced3edba0af3246b89716 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 23 Dec 2020 18:15:40 -0800 Subject: [PATCH] allow for mean pool with efficient version too --- setup.py | 2 +- vit_pytorch/efficient.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 1d557c7..74ee902 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.5.0', + version = '0.5.1', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', diff --git a/vit_pytorch/efficient.py b/vit_pytorch/efficient.py index e3ff4c3..d1a9ed2 100644 --- a/vit_pytorch/efficient.py +++ b/vit_pytorch/efficient.py @@ -3,9 +3,10 @@ from einops import rearrange, repeat from torch import nn class ViT(nn.Module): - def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, channels = 3): + def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3): super().__init__() assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' + assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' num_patches = (image_size // patch_size) ** 2 patch_dim = channels * patch_size ** 2 @@ -16,7 +17,8 @@ class ViT(nn.Module): self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) self.transformer = transformer - self.to_cls_token = nn.Identity() + self.pool = pool + self.to_latent = nn.Identity() self.mlp_head = nn.Sequential( nn.LayerNorm(dim), @@ -35,5 +37,7 @@ class ViT(nn.Module): x += self.pos_embedding[:, :(n + 1)] x = self.transformer(x) - x = self.to_cls_token(x[:, 0]) + x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] + + x = self.to_latent(x) return self.mlp_head(x)