Compare commits

...

4 Commits

Author     SHA1        Message                                      Date
Phil Wang  5135820c28  fix rvt bug                                  2021-04-19 22:18:38 -07:00
Phil Wang  4c29328363  fix frequency in rotary vision transformer  2021-04-15 16:06:32 -07:00
Phil Wang  27ac10c1f1  0.16.3                                       2021-04-14 16:53:05 -07:00
Phil Wang  fa216c45ea  tweak                                        2021-04-14 16:52:53 -07:00
2 changed files with 4 additions and 4 deletions

setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.16.2',
+  version = '0.16.5',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
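
The only substantive change here is the package version bump. As a quick sanity check (a hypothetical verification step, not part of this diff), the declared version can be confirmed after running pip install vit-pytorch==0.16.5:

# Hypothetical check, assuming the 0.16.5 release has been installed;
# importlib.metadata reads the version that setup.py recorded.
from importlib.metadata import version

print(version('vit-pytorch'))  # expected: 0.16.5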

vit_pytorch/rvt.py

@@ -19,7 +19,7 @@ class AxialRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_freq = 10):
         super().__init__()
         self.dim = dim
-        scales = torch.logspace(1., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
+        scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
         self.register_buffer('scales', scales)

     def forward(self, x):
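
The frequency fix above is an off-by-one in the logspace exponents: torch.logspace raises the base to each exponent between its first two arguments, so starting at 1. made the smallest rotary scale 2^1 = 2 and skipped the base band 2^0 = 1. A minimal sketch of the difference, with illustrative dim and max_freq values:

import torch
from math import log

dim, max_freq = 32, 10

# old start exponent 1.: scales run from 2^1 = 2 up to max_freq / 2
old = torch.logspace(1., log(max_freq / 2) / log(2), dim // 4, base = 2)
# fixed start exponent 0.: scales run from 2^0 = 1 up to max_freq / 2
new = torch.logspace(0., log(max_freq / 2) / log(2), dim // 4, base = 2)

print(old[0].item(), old[-1].item())  # 2.0 5.0
print(new[0].item(), new[-1].item())  # 1.0 5.0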
@@ -94,7 +94,7 @@ class FeedForward(nn.Module):
         return self.net(x)

 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., conv_query_kernel = 9):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., conv_query_kernel = 5):
         super().__init__()
         inner_dim = dim_head * heads
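
conv_query_kernel sizes the convolution that produces attention queries from the feature map, so this change narrows the query receptive field from 9x9 to 5x5. A minimal sketch of such a query convolution (the to_q name and 'same' padding are assumptions for illustration, not necessarily the repository's exact code):

import torch
import torch.nn as nn

dim, heads, dim_head, conv_query_kernel = 128, 8, 64, 5
inner_dim = dim_head * heads  # 512, as in the class above

# padding of kernel // 2 preserves the spatial size of the feature map
to_q = nn.Conv2d(dim, inner_dim, conv_query_kernel, padding = conv_query_kernel // 2)

fmap = torch.randn(1, dim, 8, 8)
print(to_q(fmap).shape)  # torch.Size([1, 512, 8, 8])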
@@ -192,4 +192,4 @@ class RvT(nn.Module):
         fmap_dims = {'h': h // p, 'w': w // p}
         x = self.transformer(x, fmap_dims = fmap_dims)
-        return self.mlp_head(x)
+        return self.mlp_head(x[:, 0])
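
The "fix rvt bug" commit makes the classification head consume only the first token (the class token in this model) rather than the whole (batch, tokens, dim) sequence. A minimal shape sketch (sizes are illustrative, and the Linear is a stand-in for the model's actual mlp_head):

import torch
import torch.nn as nn

x = torch.randn(2, 65, 128)      # (batch, 1 cls token + 64 patch tokens, dim)
mlp_head = nn.Linear(128, 1000)  # stand-in for the model's classification head

print(mlp_head(x).shape)         # torch.Size([2, 65, 1000]) - old: logits per token
print(mlp_head(x[:, 0]).shape)   # torch.Size([2, 1000])     - fixed: logits per image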