Compare commits

...

2 Commits

Author       SHA1         Message                                          Date
Phil Wang    566365978d   add ability to turn off rotary, for ablation     2021-04-20 09:00:27 -07:00
Phil Wang    34f78294d3   fix pooling bugs across a few new archs          2021-04-19 22:36:23 -07:00
4 changed files with 23 additions and 19 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.16.4',
+  version = '0.16.7',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

View File

@@ -149,4 +149,4 @@ class LocalViT(nn.Module):
         x = self.transformer(x)
 
-        return self.mlp_head(x)
+        return self.mlp_head(x[:, 0])

View File

@@ -162,8 +162,9 @@ class PiT(nn.Module):
             layers.append(Pool(dim))
             dim *= 2
 
-        self.layers = nn.Sequential(
-            *layers,
+        self.layers = nn.Sequential(*layers)
+
+        self.mlp_head = nn.Sequential(
             nn.LayerNorm(dim),
             nn.Linear(dim, num_classes)
         )
@@ -177,4 +178,6 @@ class PiT(nn.Module):
         x += self.pos_embedding
         x = self.dropout(x)
 
-        return self.layers(x)
+        x = self.layers(x)
+
+        return self.mlp_head(x[:, 0])
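The pooling fix in both the LocalViT and PiT diffs is the same change: the classification head is applied to the CLS token alone instead of being broadcast over the whole token sequence (and, for PiT, the head can no longer sit inside the same nn.Sequential as the pooling stages, since the CLS slice has to happen in between). A minimal sketch of the shape difference, with made-up sizes rather than the repository's own code:

import torch
import torch.nn as nn

b, n, dim, num_classes = 2, 65, 512, 1000            # hypothetical batch, tokens (CLS + 64 patches), width, classes
tokens = torch.randn(b, n, dim)                       # transformer output, CLS token at index 0

mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

before = mlp_head(tokens)          # (2, 65, 1000) - head applied to every token, the old (buggy) behavior
after  = mlp_head(tokens[:, 0])    # (2, 1000)     - head applied to the CLS token only, as in these commits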

View File

@@ -94,10 +94,10 @@ class FeedForward(nn.Module):
         return self.net(x)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., conv_query_kernel = 5):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = True, conv_query_kernel = 5):
         super().__init__()
         inner_dim = dim_head * heads
 
+        self.use_rotary = use_rotary
         self.heads = heads
         self.scale = dim_head ** -0.5
@@ -120,16 +120,17 @@ class Attention(nn.Module):
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
-        # apply 2d rotary embeddings to queries and keys, excluding CLS tokens
+        if self.use_rotary:
+            # apply 2d rotary embeddings to queries and keys, excluding CLS tokens
 
-        sin, cos = pos_emb
-        (q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))
-        q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
+            sin, cos = pos_emb
+            (q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))
+            q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
 
-        # concat back the CLS tokens
+            # concat back the CLS tokens
 
-        q = torch.cat((q_cls, q), dim = 1)
-        k = torch.cat((k_cls, k), dim = 1)
+            q = torch.cat((q_cls, q), dim = 1)
+            k = torch.cat((k_cls, k), dim = 1)
 
         dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
@@ -140,13 +141,13 @@ class Attention(nn.Module):
         return self.to_out(out)
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True):
         super().__init__()
         self.layers = nn.ModuleList([])
         self.pos_emb = AxialRotaryEmbedding(dim_head)
 
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary)),
                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
             ]))
 
     def forward(self, x, fmap_dims):
def forward(self, x, fmap_dims):
@@ -160,7 +161,7 @@ class Transformer(nn.Module):
 # Rotary Vision Transformer
 
 class RvT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True):
         super().__init__()
         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
         num_patches = (image_size // patch_size) ** 2
@@ -173,7 +174,7 @@ class RvT(nn.Module):
         )
 
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary)
 
         self.mlp_head = nn.Sequential(
             nn.LayerNorm(dim),
@@ -192,4 +193,4 @@ class RvT(nn.Module):
         fmap_dims = {'h': h // p, 'w': w // p}
         x = self.transformer(x, fmap_dims = fmap_dims)
 
-        return self.mlp_head(x)
+        return self.mlp_head(x[:, 0])
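The rotary commit threads a single use_rotary flag from RvT through Transformer down into Attention, so the 2D rotary embeddings on queries and keys can be switched off for an ablation run. A usage sketch, assuming the class is importable as vit_pytorch.rvt.RvT (the import path is a guess; the constructor arguments are the ones visible in the diff above):

import torch
from vit_pytorch.rvt import RvT   # assumed import path

model = RvT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024,
    use_rotary = False            # new flag: skip the 2d rotary embeddings for the ablation
)

img = torch.randn(1, 3, 256, 256)
preds = model(img)                # (1, 1000), pooled from the CLS token per the forward() fix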