mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
dynamic positional bias for crossformer the more efficient way as described in appendix of paper
This commit is contained in:
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.24.0',
|
||||
version = '0.24.1',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -103,12 +103,25 @@ class Attention(nn.Module):
|
||||
self.attn_type = attn_type
|
||||
self.window_size = window_size
|
||||
|
||||
self.dpb = DynamicPositionBias(dim // 4)
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(inner_dim, dim, 1)
|
||||
|
||||
# positions
|
||||
|
||||
self.dpb = DynamicPositionBias(dim // 4)
|
||||
|
||||
# calculate and store indices for retrieving bias
|
||||
|
||||
pos = torch.arange(window_size)
|
||||
grid = torch.stack(torch.meshgrid(pos, pos))
|
||||
grid = rearrange(grid, 'c i j -> (i j) c')
|
||||
rel_pos = grid[:, None] - grid[None, :]
|
||||
rel_pos += window_size - 1
|
||||
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
|
||||
|
||||
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
|
||||
|
||||
def forward(self, x):
|
||||
*_, height, width, heads, wsz, device = *x.shape, self.heads, self.window_size, x.device
|
||||
|
||||
@@ -136,12 +149,11 @@ class Attention(nn.Module):
|
||||
|
||||
# add dynamic positional bias
|
||||
|
||||
i_pos = torch.arange(wsz, device = device)
|
||||
j_pos = torch.arange(wsz, device = device)
|
||||
grid = torch.stack(torch.meshgrid(i_pos, j_pos))
|
||||
grid = rearrange(grid, 'c i j -> (i j) c')
|
||||
rel_ij = grid[:, None] - grid[None, :]
|
||||
rel_pos_bias = self.dpb(rel_ij.float())
|
||||
pos = torch.arange(-wsz, wsz + 1, device = device)
|
||||
rel_pos = torch.stack(torch.meshgrid(pos, pos))
|
||||
rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
|
||||
biases = self.dpb(rel_pos.float())
|
||||
rel_pos_bias = biases[self.rel_pos_indices]
|
||||
|
||||
sim = sim + rel_pos_bias
|
||||
|
||||
|
||||
Reference in New Issue
Block a user