cleanup levit

This commit is contained in:
Phil Wang
2021-04-06 13:46:19 -07:00
parent e075460937
commit b50d3e1334
3 changed files with 5 additions and 7 deletions

View File

@@ -279,11 +279,10 @@ levit = LeViT(
num_classes = 1000,
stages = 3, # number of stages
dim = (256, 384, 512), # dimensions at each stage
depth = 4,
depth = 4, # transformer of depth 4 at each stage
heads = (4, 6, 8), # heads at each stage
mlp_mult = 2,
dropout = 0.1,
emb_dropout = 0.1
dropout = 0.1
)
img = torch.randn(1, 3, 224, 224)

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.15.1',
version = '0.15.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -135,7 +135,6 @@ class LeViT(nn.Module):
dim_key = 32,
dim_value = 64,
dropout = 0.,
emb_dropout = 0.,
num_distill_classes = None
):
super().__init__()
@@ -146,7 +145,7 @@ class LeViT(nn.Module):
assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'
self.to_patch_embedding = nn.Sequential(
self.conv_embedding = nn.Sequential(
nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
@@ -176,7 +175,7 @@ class LeViT(nn.Module):
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, img):
x = self.to_patch_embedding(img)
x = self.conv_embedding(img)
x = self.backbone(x)