From b69b5af34f7759948425113f6dc3b30dfb91d4d1 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 22 Nov 2021 17:39:36 -0800 Subject: [PATCH] dynamic positional bias for crossformer the more efficient way as described in appendix of paper --- setup.py | 2 +- vit_pytorch/crossformer.py | 28 ++++++++++++++++++++-------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index ad464cb..efed44a 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.24.0', + version = '0.24.1', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', diff --git a/vit_pytorch/crossformer.py b/vit_pytorch/crossformer.py index 673c891..2236100 100644 --- a/vit_pytorch/crossformer.py +++ b/vit_pytorch/crossformer.py @@ -103,12 +103,25 @@ class Attention(nn.Module): self.attn_type = attn_type self.window_size = window_size - self.dpb = DynamicPositionBias(dim // 4) - self.norm = LayerNorm(dim) self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False) self.to_out = nn.Conv2d(inner_dim, dim, 1) + # positions + + self.dpb = DynamicPositionBias(dim // 4) + + # calculate and store indices for retrieving bias + + pos = torch.arange(window_size) + grid = torch.stack(torch.meshgrid(pos, pos)) + grid = rearrange(grid, 'c i j -> (i j) c') + rel_pos = grid[:, None] - grid[None, :] + rel_pos += window_size - 1 + rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1) + + self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False) + def forward(self, x): *_, height, width, heads, wsz, device = *x.shape, self.heads, self.window_size, x.device @@ -136,12 +149,11 @@ class Attention(nn.Module): # add dynamic positional bias - i_pos = torch.arange(wsz, device = device) - j_pos = torch.arange(wsz, device = device) - grid = torch.stack(torch.meshgrid(i_pos, j_pos)) - grid = rearrange(grid, 'c i j -> (i j) c') - rel_ij = grid[:, None] - grid[None, :] - rel_pos_bias = self.dpb(rel_ij.float()) + pos = torch.arange(-wsz, wsz + 1, device = device) + rel_pos = torch.stack(torch.meshgrid(pos, pos)) + rel_pos = rearrange(rel_pos, 'c i j -> (i j) c') + biases = self.dpb(rel_pos.float()) + rel_pos_bias = biases[self.rel_pos_indices] sim = sim + rel_pos_bias