mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
fix window size of none for scalable vit for rectangular images
This commit is contained in:
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.28.1',
|
||||
version = '0.28.2',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -156,8 +156,8 @@ class InteractiveWindowedSelfAttention(nn.Module):
|
||||
def forward(self, x):
|
||||
height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size
|
||||
|
||||
wsz = default(wsz, height) # take height as window size if not given
|
||||
assert (height % wsz) == 0 and (width % wsz) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz})'
|
||||
wsz_h, wsz_w = default(wsz, height), default(wsz, width)
|
||||
assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'
|
||||
|
||||
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
|
||||
|
||||
@@ -167,7 +167,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
|
||||
|
||||
# divide into window (and split out heads) for efficient self attention
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz, w2 = wsz), (q, k, v))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz_h, w2 = wsz_w), (q, k, v))
|
||||
|
||||
# similarity
|
||||
|
||||
@@ -183,7 +183,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
|
||||
|
||||
# reshape the windows back to full feature map (and merge heads)
|
||||
|
||||
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
|
||||
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w)
|
||||
|
||||
# add LIM output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user