diff --git a/setup.py b/setup.py
index 061934f..4489b91 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.6.5',
+  version = '1.6.6',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
diff --git a/vit_pytorch/cross_vit.py b/vit_pytorch/cross_vit.py
index 8a1bc56..11156c2 100644
--- a/vit_pytorch/cross_vit.py
+++ b/vit_pytorch/cross_vit.py
@@ -170,12 +170,13 @@ class ImageEmbedder(nn.Module):
         dim,
         image_size,
         patch_size,
-        dropout = 0.
+        dropout = 0.,
+        channels = 3
     ):
         super().__init__()
         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
         num_patches = (image_size // patch_size) ** 2
-        patch_dim = 3 * patch_size ** 2
+        patch_dim = channels * patch_size ** 2
 
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
@@ -223,11 +224,12 @@ class CrossViT(nn.Module):
         cross_attn_dim_head = 64,
         depth = 3,
         dropout = 0.1,
-        emb_dropout = 0.1
+        emb_dropout = 0.1,
+        channels = 3
     ):
         super().__init__()
 
-        self.sm_image_embedder = ImageEmbedder(dim = sm_dim, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
-        self.lg_image_embedder = ImageEmbedder(dim = lg_dim, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)
+        self.sm_image_embedder = ImageEmbedder(dim = sm_dim, channels = channels, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
+        self.lg_image_embedder = ImageEmbedder(dim = lg_dim, channels = channels, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)
         self.multi_scale_encoder = MultiScaleEncoder(
             depth = depth,
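
For reference, a minimal usage sketch of the new channels argument (the hyperparameter values below are illustrative, not taken from the diff): with this change, CrossViT accepts non-RGB input, e.g. single-channel images, and channels is threaded through to both the small- and large-patch embedders.

import torch
from vit_pytorch.cross_vit import CrossViT

# patch sizes are chosen to divide image_size evenly, as the ImageEmbedder assert requires
v = CrossViT(
    image_size = 256,
    num_classes = 1000,
    sm_dim = 192,          # small-patch branch embedding dimension
    sm_patch_size = 16,
    lg_dim = 384,          # large-patch branch embedding dimension
    lg_patch_size = 32,
    channels = 1           # new: grayscale input instead of the previously hard-coded 3
)

img = torch.randn(1, 1, 256, 256)   # (batch, channels, height, width)
preds = v(img)                      # (1, 1000)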