Compare commits

...

5 Commits

Author SHA1 Message Date
Phil Wang
4c29328363 fix frequency in rotary vision transformer 2021-04-15 16:06:32 -07:00
Phil Wang
27ac10c1f1 0.16.3 2021-04-14 16:53:05 -07:00
Phil Wang
fa216c45ea tweak 2021-04-14 16:52:53 -07:00
Phil Wang
1d8b7826bf update personal pet vit 2021-04-14 15:56:39 -07:00
Phil Wang
53b3af05f6 use convolution on query with padding to give the network absolute spatial awareness in addition to relative encoding from rotary embeddings 2021-04-14 15:56:02 -07:00
2 changed files with 42 additions and 10 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.16.1',
version = '0.16.4',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -19,7 +19,7 @@ class AxialRotaryEmbedding(nn.Module):
def __init__(self, dim, max_freq = 10):
super().__init__()
self.dim = dim
scales = torch.logspace(1., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
self.register_buffer('scales', scales)
def forward(self, x):
@@ -43,6 +43,16 @@ class AxialRotaryEmbedding(nn.Module):
sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos))
return sin, cos
class DepthWiseConv2d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
)
def forward(self, x):
return self.net(x)
# helper classes
class PreNorm(nn.Module):
@@ -53,6 +63,18 @@ class PreNorm(nn.Module):
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class SpatialConv(nn.Module):
def __init__(self, dim_in, dim_out, kernel, bias = False):
super().__init__()
self.conv = DepthWiseConv2d(dim_in, dim_out, kernel, padding = kernel // 2, bias = False)
def forward(self, x, fmap_dims):
cls_token, x = x[:, :1], x[:, 1:]
x = rearrange(x, 'b (h w) d -> b d h w', **fmap_dims)
x = self.conv(x)
x = rearrange(x, 'b d h w -> b (h w) d')
return torch.cat((cls_token, x), dim = 1)
class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim = -1)
@@ -72,7 +94,7 @@ class FeedForward(nn.Module):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., conv_query_kernel = 5):
super().__init__()
inner_dim = dim_head * heads
@@ -80,16 +102,22 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, pos_emb):
def forward(self, x, pos_emb, fmap_dims):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q = self.to_q(x, fmap_dims = fmap_dims)
qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
# apply 2d rotary embeddings to queries and keys, excluding CLS tokens
@@ -121,11 +149,11 @@ class Transformer(nn.Module):
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
def forward(self, x, fmap_dims):
pos_emb = self.pos_emb(x[:, 1:])
for attn, ff in self.layers:
x = attn(x, pos_emb = pos_emb) + x
x = attn(x, pos_emb = pos_emb, fmap_dims = fmap_dims) + x
x = ff(x) + x
return x
@@ -138,6 +166,7 @@ class RvT(nn.Module):
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
self.patch_size = patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear(patch_dim, dim),
@@ -152,12 +181,15 @@ class RvT(nn.Module):
)
def forward(self, img):
b, _, h, w, p = *img.shape, self.patch_size
x = self.to_patch_embedding(img)
b, n, _ = x.shape
n = x.shape[1]
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x = self.transformer(x)
fmap_dims = {'h': h // p, 'w': w // p}
x = self.transformer(x, fmap_dims = fmap_dims)
return self.mlp_head(x)