mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 16:12:29 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5135820c28 | ||
|
|
4c29328363 | ||
|
|
27ac10c1f1 | ||
|
|
fa216c45ea |
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.16.2',
|
||||
version = '0.16.5',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -19,7 +19,7 @@ class AxialRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_freq = 10):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
scales = torch.logspace(1., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
|
||||
scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
|
||||
self.register_buffer('scales', scales)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -94,7 +94,7 @@ class FeedForward(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., conv_query_kernel = 9):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., conv_query_kernel = 5):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
@@ -192,4 +192,4 @@ class RvT(nn.Module):
|
||||
fmap_dims = {'h': h // p, 'w': w // p}
|
||||
x = self.transformer(x, fmap_dims = fmap_dims)
|
||||
|
||||
return self.mlp_head(x)
|
||||
return self.mlp_head(x[:, 0])
|
||||
|
||||
Reference in New Issue
Block a user