From 566365978dd0e4fec7f73915c9976216c983666f Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 20 Apr 2021 09:00:27 -0700
Subject: [PATCH] add ability to turn off rotary, for ablation

---
 setup.py           |  2 +-
 vit_pytorch/rvt.py | 27 ++++++++++++++-------------
 2 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/setup.py b/setup.py
index ad7bfd9..b2b7fdb 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.16.6',
+  version = '0.16.7',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/rvt.py b/vit_pytorch/rvt.py
index 4e5b840..9c70e16 100644
--- a/vit_pytorch/rvt.py
+++ b/vit_pytorch/rvt.py
@@ -94,10 +94,10 @@ class FeedForward(nn.Module):
         return self.net(x)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., conv_query_kernel = 5):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = True, conv_query_kernel = 5):
         super().__init__()
         inner_dim = dim_head * heads
-
+        self.use_rotary = use_rotary
         self.heads = heads
         self.scale = dim_head ** -0.5
 
@@ -120,16 +120,17 @@ class Attention(nn.Module):
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
-        # apply 2d rotary embeddings to queries and keys, excluding CLS tokens
+        if self.use_rotary:
+            # apply 2d rotary embeddings to queries and keys, excluding CLS tokens
 
-        sin, cos = pos_emb
-        (q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))
-        q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
+            sin, cos = pos_emb
+            (q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))
+            q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
 
-        # concat back the CLS tokens
+            # concat back the CLS tokens
 
-        q = torch.cat((q_cls, q), dim = 1)
-        k = torch.cat((k_cls, k), dim = 1)
+            q = torch.cat((q_cls, q), dim = 1)
+            k = torch.cat((k_cls, k), dim = 1)
 
         dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
 
@@ -140,13 +141,13 @@ class Attention(nn.Module):
         return self.to_out(out)
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True):
         super().__init__()
         self.layers = nn.ModuleList([])
         self.pos_emb = AxialRotaryEmbedding(dim_head)
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary)),
                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
             ]))
     def forward(self, x, fmap_dims):
@@ -160,7 +161,7 @@ class Transformer(nn.Module):
 # Rotary Vision Transformer
 
 class RvT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True):
         super().__init__()
         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
         num_patches = (image_size // patch_size) ** 2
@@ -173,7 +174,7 @@ class RvT(nn.Module):
         )
 
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary)
 
         self.mlp_head = nn.Sequential(
             nn.LayerNorm(dim),
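
A minimal usage sketch of the new flag, assuming vit-pytorch 0.16.7 as bumped in setup.py above; the hyperparameters and input resolution below are illustrative assumptions, not values from the patch. Constructing RvT with use_rotary = False instantiates the same architecture with the 2d rotary embeddings skipped in every attention layer, which is the ablation this commit enables.

    # minimal sketch, assuming vit_pytorch 0.16.7 (as bumped in setup.py above)
    # hyperparameters and input size are illustrative, not taken from the patch
    import torch
    from vit_pytorch.rvt import RvT

    model = RvT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 512,
        depth = 6,
        heads = 8,
        mlp_dim = 1024,
        use_rotary = False   # new flag: disable 2d rotary embeddings for the ablation run
    )

    img = torch.randn(1, 3, 256, 256)
    preds = model(img)       # expected class logits of shape (1, 1000)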