mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
fix all issues with rotary vision transformer
This commit is contained in:
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.16.8',
|
||||
version = '0.16.9',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -67,12 +67,14 @@ class SpatialConv(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, kernel, bias = False):
|
||||
super().__init__()
|
||||
self.conv = DepthWiseConv2d(dim_in, dim_out, kernel, padding = kernel // 2, bias = False)
|
||||
self.cls_proj = nn.Linear(dim_in, dim_out) if dim_in != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, fmap_dims):
|
||||
cls_token, x = x[:, :1], x[:, 1:]
|
||||
x = rearrange(x, 'b (h w) d -> b d h w', **fmap_dims)
|
||||
x = self.conv(x)
|
||||
x = rearrange(x, 'b d h w -> b (h w) d')
|
||||
cls_token = self.cls_proj(cls_token)
|
||||
return torch.cat((cls_token, x), dim = 1)
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
@@ -104,6 +106,7 @@ class Attention(nn.Module):
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.use_ds_conv = use_ds_conv
|
||||
|
||||
self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False) if use_ds_conv else nn.Linear(dim, inner_dim, bias = False)
|
||||
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
@@ -126,8 +129,15 @@ class Attention(nn.Module):
|
||||
# apply 2d rotary embeddings to queries and keys, excluding CLS tokens
|
||||
|
||||
sin, cos = pos_emb
|
||||
dim_rotary = sin.shape[-1]
|
||||
|
||||
(q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))
|
||||
|
||||
# handle the case where rotary dimension < head dimension
|
||||
|
||||
(q, q_pass), (k, k_pass) = map(lambda t: (t[..., :dim_rotary], t[..., dim_rotary:]), (q, k))
|
||||
q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
|
||||
q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
|
||||
|
||||
# concat back the CLS tokens
|
||||
|
||||
|
||||
Reference in New Issue
Block a user