give a learned bias to and from registers for maxvit + register token variant

This commit is contained in:
lucidrains
2023-10-06 10:40:26 -07:00
parent df8733d86e
commit bbb24e34d4
2 changed files with 14 additions and 13 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.5.2',
version = '1.5.3',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',

View File

@@ -119,9 +119,11 @@ class Attention(Module):
dim,
dim_head = 32,
dropout = 0.,
window_size = 7
window_size = 7,
num_registers = 1
):
super().__init__()
assert num_registers > 0
assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
self.heads = dim // dim_head
@@ -142,7 +144,9 @@ class Attention(Module):
# relative positional bias
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
num_rel_pos_bias = (2 * window_size - 1) ** 2
self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
@@ -151,10 +155,11 @@ class Attention(Module):
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
def forward(self, x):
device, h = x.device, self.heads
device, h, bias_indices = x.device, self.heads, self.rel_pos_indices
x = self.norm(x)
@@ -176,13 +181,8 @@ class Attention(Module):
# add positional bias
bias = self.rel_pos_bias(self.rel_pos_indices)
bias = rearrange(bias, 'i j h -> h i j')
num_registers = sim.shape[-1] - bias.shape[-1]
bias = F.pad(bias, (num_registers, 0, num_registers, 0), value = 0.)
sim = sim + bias
bias = self.rel_pos_bias(bias_indices)
sim = sim + rearrange(bias, 'i j h -> h i j')
# attention
@@ -215,6 +215,7 @@ class MaxViT(Module):
):
super().__init__()
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
assert num_register_tokens > 0
# convolutional stem
@@ -256,10 +257,10 @@ class MaxViT(Module):
shrinkage_rate = mbconv_shrinkage_rate
)
block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
block_ff = FeedForward(dim = layer_dim, dropout = dropout)
grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
grid_ff = FeedForward(dim = layer_dim, dropout = dropout)
register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))