From bad4b94e7b4baa544ca36149431f7912eccd4b49 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 25 Apr 2021 12:09:32 -0700
Subject: [PATCH] fix all issues with rotary vision transformer

---
 setup.py           |  2 +-
 vit_pytorch/rvt.py | 10 ++++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 83f4dc4..1441ffa 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.16.8',
+  version = '0.16.9',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/rvt.py b/vit_pytorch/rvt.py
index 42dcd09..eeb842b 100644
--- a/vit_pytorch/rvt.py
+++ b/vit_pytorch/rvt.py
@@ -67,12 +67,14 @@ class SpatialConv(nn.Module):
     def __init__(self, dim_in, dim_out, kernel, bias = False):
         super().__init__()
         self.conv = DepthWiseConv2d(dim_in, dim_out, kernel, padding = kernel // 2, bias = False)
+        self.cls_proj = nn.Linear(dim_in, dim_out) if dim_in != dim_out else nn.Identity()
 
     def forward(self, x, fmap_dims):
         cls_token, x = x[:, :1], x[:, 1:]
         x = rearrange(x, 'b (h w) d -> b d h w', **fmap_dims)
         x = self.conv(x)
         x = rearrange(x, 'b d h w -> b (h w) d')
+        cls_token = self.cls_proj(cls_token)
         return torch.cat((cls_token, x), dim = 1)
 
 class GEGLU(nn.Module):
@@ -104,6 +106,7 @@ class Attention(nn.Module):
         self.attend = nn.Softmax(dim = -1)
 
         self.use_ds_conv = use_ds_conv
+
         self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False) if use_ds_conv else nn.Linear(dim, inner_dim, bias = False)
 
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -126,8 +129,15 @@ class Attention(nn.Module):
         # apply 2d rotary embeddings to queries and keys, excluding CLS tokens
 
         sin, cos = pos_emb
+        dim_rotary = sin.shape[-1]
+
         (q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))
+
+        # handle the case where rotary dimension < head dimension
+
+        (q, q_pass), (k, k_pass) = map(lambda t: (t[..., :dim_rotary], t[..., dim_rotary:]), (q, k))
         q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
+        q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
 
         # concat back the CLS tokens
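
Notes on the fix (not part of the patch): the SpatialConv change projects the CLS token through a separate nn.Linear whenever the depthwise convolution changes the channel dimension, so it can still be concatenated back onto the convolved feature map. The attention change applies the 2d rotary embedding only to the first dim_rotary channels of each query/key and passes the remaining channels through unrotated. Below is a minimal standalone sketch of that partial-rotary step; the tensor shapes, the rotate_every_two helper, and the apply_partial_rotary wrapper are illustrative re-statements for this note, not code taken from the repository.

import torch

def rotate_every_two(x):
    # pairwise rotation used by rotary embeddings: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim = -1).flatten(-2)

def apply_partial_rotary(q, k, sin, cos):
    # sin / cos cover only the first dim_rotary channels of each head;
    # the trailing channels are concatenated back unrotated, mirroring the hunk above
    dim_rotary = sin.shape[-1]
    (q, q_pass), (k, k_pass) = map(lambda t: (t[..., :dim_rotary], t[..., dim_rotary:]), (q, k))
    q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
    q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
    return q, k

# example shapes only (placeholder sin / cos values; in rvt.py they come from the rotary position embedding)
q = torch.randn(1, 8, 64, 64)                        # (batch, heads, tokens, head_dim)
k = torch.randn(1, 8, 64, 64)
sin, cos = torch.randn(64, 32), torch.randn(64, 32)  # rotary covers 32 of the 64 head channels
q, k = apply_partial_rotary(q, k, sin, cos)          # shapes are unchanged: (1, 8, 64, 64)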