From ff44d97cb0ce222bd9ad86563f9c0d83f660590c Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 22 Nov 2021 18:08:49 -0800
Subject: [PATCH] make initial channels customizable for PiT

---
 setup.py           | 2 +-
 vit_pytorch/pit.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/setup.py b/setup.py
index efed44a..1436b3d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.24.1',
+  version = '0.24.2',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/pit.py b/vit_pytorch/pit.py
index a5f92f8..35d123d 100644
--- a/vit_pytorch/pit.py
+++ b/vit_pytorch/pit.py
@@ -129,14 +129,15 @@ class PiT(nn.Module):
         mlp_dim,
         dim_head = 64,
         dropout = 0.,
-        emb_dropout = 0.
+        emb_dropout = 0.,
+        channels = 3
     ):
         super().__init__()
         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
         assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing'
         heads = cast_tuple(heads, len(depth))

-        patch_dim = 3 * patch_size ** 2
+        patch_dim = channels * patch_size ** 2

         self.to_patch_embedding = nn.Sequential(
             nn.Unfold(kernel_size = patch_size, stride = patch_size // 2),