Compare commits

...

4 Commits

Author SHA1 Message Date
Phil Wang
b50d3e1334 cleanup levit 2021-04-06 13:46:19 -07:00
Phil Wang
e075460937 stray print 2021-04-06 13:38:52 -07:00
Phil Wang
5e23e48e4d Merge pull request #88 from lucidrains/levit
fix images
2021-04-06 13:37:46 -07:00
Phil Wang
0f31ca79e3 Merge pull request #87 from lucidrains/levit
levit without pos emb
2021-04-06 13:36:26 -07:00
3 changed files with 5 additions and 8 deletions

View File

@@ -279,11 +279,10 @@ levit = LeViT(
num_classes = 1000,
stages = 3, # number of stages
dim = (256, 384, 512), # dimensions at each stage
depth = 4,
depth = 4, # transformer of depth 4 at each stage
heads = (4, 6, 8), # heads at each stage
mlp_mult = 2,
dropout = 0.1,
emb_dropout = 0.1
dropout = 0.1
)
img = torch.randn(1, 3, 224, 224)

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.15.0',
version = '0.15.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -81,7 +81,6 @@ class Attention(nn.Module):
def apply_pos_bias(self, fmap):
bias = self.pos_bias(self.pos_indices)
bias = rearrange(bias, 'i j h -> () h i j')
print(bias.shape, fmap.shape)
return fmap + bias
def forward(self, x):
@@ -136,7 +135,6 @@ class LeViT(nn.Module):
dim_key = 32,
dim_value = 64,
dropout = 0.,
emb_dropout = 0.,
num_distill_classes = None
):
super().__init__()
@@ -147,7 +145,7 @@ class LeViT(nn.Module):
assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'
self.to_patch_embedding = nn.Sequential(
self.conv_embedding = nn.Sequential(
nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
@@ -177,7 +175,7 @@ class LeViT(nn.Module):
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, img):
x = self.to_patch_embedding(img)
x = self.conv_embedding(img)
x = self.backbone(x)