From b50d3e13347f2d34732d3158fc3e5a0897e4dc4b Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 6 Apr 2021 13:46:19 -0700 Subject: [PATCH] cleanup levit --- README.md | 5 ++--- setup.py | 2 +- vit_pytorch/levit.py | 5 ++--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index aa1b909..f61e766 100644 --- a/README.md +++ b/README.md @@ -279,11 +279,10 @@ levit = LeViT( num_classes = 1000, stages = 3, # number of stages dim = (256, 384, 512), # dimensions at each stage - depth = 4, + depth = 4, # transformer of depth 4 at each stage heads = (4, 6, 8), # heads at each stage mlp_mult = 2, - dropout = 0.1, - emb_dropout = 0.1 + dropout = 0.1 ) img = torch.randn(1, 3, 224, 224) diff --git a/setup.py b/setup.py index 26bb607..e34f1cc 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.15.1', + version = '0.15.2', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', diff --git a/vit_pytorch/levit.py b/vit_pytorch/levit.py index e2c7e66..79c2957 100644 --- a/vit_pytorch/levit.py +++ b/vit_pytorch/levit.py @@ -135,7 +135,6 @@ class LeViT(nn.Module): dim_key = 32, dim_value = 64, dropout = 0., - emb_dropout = 0., num_distill_classes = None ): super().__init__() @@ -146,7 +145,7 @@ class LeViT(nn.Module): assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages' - self.to_patch_embedding = nn.Sequential( + self.conv_embedding = nn.Sequential( nn.Conv2d(3, 32, 3, stride = 2, padding = 1), nn.Conv2d(32, 64, 3, stride = 2, padding = 1), nn.Conv2d(64, 128, 3, stride = 2, padding = 1), @@ -176,7 +175,7 @@ class LeViT(nn.Module): self.mlp_head = nn.Linear(dim, num_classes) def forward(self, img): - x = self.to_patch_embedding(img) + x = self.conv_embedding(img) x = self.backbone(x)