dynamic positional bias for crossformer, the more efficient way as described in the appendix of the paper

Phil Wang
2021-11-22 17:39:36 -08:00
parent 36e32b70fb
commit b69b5af34f
2 changed files with 21 additions and 9 deletions
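
Note on the change: the index table for the relative-position bias is now precomputed once in __init__, and each forward pass runs the DynamicPositionBias MLP only over the small grid of relative offsets before gathering them into the full window-to-window bias, instead of feeding every pair of window positions through the MLP. Below is a minimal, self-contained sketch of that idea, not the exact code in the diff: names like bias_mlp and the value of window_size are placeholders, and the offset grid here is the compact (2·w − 1)-wide one, slightly smaller than the grid built in the diff's forward pass.

import torch
from torch import nn

# illustrative values and stand-ins, not the library's API
window_size = 7
bias_mlp = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 1))  # stand-in for DynamicPositionBias

# precompute once: map every (query cell, key cell) pair to the index of its relative offset
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))             # (2, w, w)
grid = grid.reshape(2, -1).t()                                            # (w*w, 2) cell coordinates
rel_pos = grid[:, None] - grid[None, :] + (window_size - 1)               # offsets shifted into [0, 2w-2]
rel_pos_indices = rel_pos[..., 0] * (2 * window_size - 1) + rel_pos[..., 1]

# per forward pass: evaluate the MLP once per unique offset, then gather
offsets = torch.arange(-(window_size - 1), window_size)                   # 2w-1 offset values per axis
rel = torch.stack(torch.meshgrid(offsets, offsets, indexing = 'ij')).reshape(2, -1).t()
biases = bias_mlp(rel.float()).squeeze(-1)                                # ((2w-1)^2,) one bias per offset
rel_pos_bias = biases[rel_pos_indices]                                    # (w*w, w*w), added to attention logits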


@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.24.0',
+  version = '0.24.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',


@@ -103,12 +103,25 @@ class Attention(nn.Module):
         self.attn_type = attn_type
         self.window_size = window_size
-        self.dpb = DynamicPositionBias(dim // 4)
         self.norm = LayerNorm(dim)
         self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
         self.to_out = nn.Conv2d(inner_dim, dim, 1)
+        # positions
+        self.dpb = DynamicPositionBias(dim // 4)
+        # calculate and store indices for retrieving bias
+        pos = torch.arange(window_size)
+        grid = torch.stack(torch.meshgrid(pos, pos))
+        grid = rearrange(grid, 'c i j -> (i j) c')
+        rel_pos = grid[:, None] - grid[None, :]
+        rel_pos += window_size - 1
+        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
+        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
     def forward(self, x):
         *_, height, width, heads, wsz, device = *x.shape, self.heads, self.window_size, x.device
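
For intuition about what the new buffer holds: with a toy window_size = 2 (purely illustrative), rel_pos_indices is a 4 × 4 integer map from each (query cell, key cell) pair inside the window to the flattened index of their relative offset.

import torch
from einops import rearrange

window_size = 2                                   # toy value for illustration
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')        # the 4 cells: (0,0), (0,1), (1,0), (1,1)
rel_pos = grid[:, None] - grid[None, :]
rel_pos += window_size - 1                        # shift offsets into [0, 2]
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
print(rel_pos_indices)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])   values lie in [0, (2 * window_size - 1) ** 2 - 1]

The diagonal is the zero offset; equal entries mark pairs of cells that share the same relative displacement, so their bias only has to be computed once.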
@@ -136,12 +149,11 @@ class Attention(nn.Module):
         # add dynamic positional bias
-        i_pos = torch.arange(wsz, device = device)
-        j_pos = torch.arange(wsz, device = device)
-        grid = torch.stack(torch.meshgrid(i_pos, j_pos))
-        grid = rearrange(grid, 'c i j -> (i j) c')
-        rel_ij = grid[:, None] - grid[None, :]
-        rel_pos_bias = self.dpb(rel_ij.float())
+        pos = torch.arange(-wsz, wsz + 1, device = device)
+        rel_pos = torch.stack(torch.meshgrid(pos, pos))
+        rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
+        biases = self.dpb(rel_pos.float())
+        rel_pos_bias = biases[self.rel_pos_indices]
         sim = sim + rel_pos_bias
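
The payoff is in how many rows go through the bias MLP each forward pass: the removed code evaluated it over every (query, key) pair in the window, while the new code evaluates it once per offset on the small grid and reuses the stored indices to gather. A rough count, assuming a window size of 7 purely for illustration:

wsz = 7                                  # illustrative window size
old_mlp_rows = (wsz * wsz) ** 2          # one MLP input per (query, key) pair inside the window
new_mlp_rows = (2 * wsz + 1) ** 2        # one per offset on the grid built in the new forward pass
print(old_mlp_rows, new_mlp_rows)        # 2401 225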