From c2b2db2a5451f03730a29157c6ff8c5099c59d57 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 22 Mar 2022 17:37:59 -0700 Subject: [PATCH] fix window size of none for scalable vit for rectangular images --- setup.py | 2 +- vit_pytorch/scalable_vit.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 990dad8..4eb67ad 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.28.1', + version = '0.28.2', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', diff --git a/vit_pytorch/scalable_vit.py b/vit_pytorch/scalable_vit.py index 11a21e0..3dbb1be 100644 --- a/vit_pytorch/scalable_vit.py +++ b/vit_pytorch/scalable_vit.py @@ -156,8 +156,8 @@ class InteractiveWindowedSelfAttention(nn.Module): def forward(self, x): height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size - wsz = default(wsz, height) # take height as window size if not given - assert (height % wsz) == 0 and (width % wsz) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz})' + wsz_h, wsz_w = default(wsz, height), default(wsz, width) + assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})' q, k, v = self.to_q(x), self.to_k(x), self.to_v(x) @@ -167,7 +167,7 @@ class InteractiveWindowedSelfAttention(nn.Module): # divide into window (and split out heads) for efficient self attention - q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz, w2 = wsz), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz_h, w2 = wsz_w), (q, k, v)) # similarity @@ -183,7 +183,7 @@ class InteractiveWindowedSelfAttention(nn.Module): # reshape the windows back to full feature map (and merge heads) - out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz) + out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w) # add LIM output