diff --git a/README.md b/README.md
index daa7660..232de0b 100644
--- a/README.md
+++ b/README.md
@@ -613,6 +613,29 @@ img = torch.randn(1, 3, 256, 128) # <-- not a square
 preds = v(img) # (1, 1000)
 ```
 
+- How do I pass in non-square patches?
+
+```python
+import torch
+from vit_pytorch import ViT
+
+v = ViT(
+    num_classes = 1000,
+    image_size = (256, 128),  # image size is a tuple of (height, width)
+    patch_size = (32, 16),    # patch size is a tuple of (height, width)
+    dim = 1024,
+    depth = 6,
+    heads = 16,
+    mlp_dim = 2048,
+    dropout = 0.1,
+    emb_dropout = 0.1
+)
+
+img = torch.randn(1, 3, 256, 128)
+
+preds = v(img)
+```
+
 ## Resources
 
 Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
diff --git a/setup.py b/setup.py
index ce7ba0e..9922c4b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.17.2',
+  version = '0.17.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/vit.py b/vit_pytorch/vit.py
index 90d1977..519ef79 100644
--- a/vit_pytorch/vit.py
+++ b/vit_pytorch/vit.py
@@ -5,6 +5,13 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 
+# helpers
+
+def pair(t):
+    return t if isinstance(t, tuple) else (t, t)
+
+# classes
+
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
@@ -74,13 +81,17 @@ class Transformer(nn.Module):
 class ViT(nn.Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
         super().__init__()
-        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
-        num_patches = (image_size // patch_size) ** 2
-        patch_dim = channels * patch_size ** 2
+        image_height, image_width = pair(image_size)
+        patch_height, patch_width = pair(patch_size)
+
+        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+
+        num_patches = (image_height // patch_height) * (image_width // patch_width)
+        patch_dim = channels * patch_height * patch_width
         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
 
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
             nn.Linear(patch_dim, dim),
         )
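For quick reference, below is a minimal standalone sketch of the rectangular-patch arithmetic introduced in the `vit.py` hunk above (the `pair` helper plus the new `num_patches` / `patch_dim` computation), using the same example sizes as the README snippet. It is illustrative only and does not require `vit_pytorch` to be installed.

```python
# Minimal sketch mirroring the new rectangular-patch logic in vit.py (illustrative only).

def pair(t):
    # Promote a bare int to an (int, int) tuple, exactly as the new helper does.
    return t if isinstance(t, tuple) else (t, t)

image_height, image_width = pair((256, 128))  # image size from the README example
patch_height, patch_width = pair((32, 16))    # patch size from the README example
channels = 3

# Same divisibility check as the updated ViT.__init__
assert image_height % patch_height == 0 and image_width % patch_width == 0

num_patches = (image_height // patch_height) * (image_width // patch_width)  # 8 * 8 = 64
patch_dim = channels * patch_height * patch_width                            # 3 * 32 * 16 = 1536

print(num_patches, patch_dim)  # 64 1536
```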