mirror of https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00

commit: cleanup levit
README.md
@@ -279,11 +279,10 @@ levit = LeViT(
     num_classes = 1000,
     stages = 3,             # number of stages
     dim = (256, 384, 512),  # dimensions at each stage
-    depth = 4,
+    depth = 4,              # transformer of depth 4 at each stage
     heads = (4, 6, 8),      # heads at each stage
     mlp_mult = 2,
-    dropout = 0.1,
-    emb_dropout = 0.1
+    dropout = 0.1
 )

 img = torch.randn(1, 3, 224, 224)
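For context, in the README this snippet is followed by a forward pass. A minimal usage sketch against the post-commit API (the image_size argument sits above the hunk and is assumed here; emb_dropout is no longer accepted):

import torch
from vit_pytorch.levit import LeViT

levit = LeViT(
    image_size = 224,        # assumed from README context above the hunk
    num_classes = 1000,
    stages = 3,              # number of stages
    dim = (256, 384, 512),   # dimensions at each stage
    depth = 4,               # transformer of depth 4 at each stage
    heads = (4, 6, 8),       # heads at each stage
    mlp_mult = 2,
    dropout = 0.1            # emb_dropout removed in this commit
)

img = torch.randn(1, 3, 224, 224)
logits = levit(img)          # expected shape: (1, 1000)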
setup.py (2 changed lines)
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.15.1',
+  version = '0.15.2',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
vit_pytorch/levit.py
@@ -135,7 +135,6 @@ class LeViT(nn.Module):
         dim_key = 32,
         dim_value = 64,
         dropout = 0.,
-        emb_dropout = 0.,
         num_distill_classes = None
     ):
         super().__init__()
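Because these are keyword arguments, any call site still passing the removed emb_dropout now fails loudly with a TypeError rather than silently ignoring it; a minimal sketch (reusing the README configuration, which is an assumption here):

from vit_pytorch.levit import LeViT

try:
    LeViT(
        image_size = 224, num_classes = 1000, stages = 3,
        dim = (256, 384, 512), depth = 4, heads = (4, 6, 8),
        mlp_mult = 2, dropout = 0.1,
        emb_dropout = 0.1    # removed in this commit
    )
except TypeError as err:
    print(err)               # ... got an unexpected keyword argument 'emb_dropout'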
@@ -146,7 +145,7 @@ class LeViT(nn.Module):

         assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'

-        self.to_patch_embedding = nn.Sequential(
+        self.conv_embedding = nn.Sequential(
             nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
             nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
             nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
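The rename reflects what the module actually is: a convolutional stem rather than a patch embedding. Each stride-2 convolution halves spatial resolution; a standalone sketch of the three layers visible in the hunk (the stem may continue past what the hunk shows):

import torch
import torch.nn as nn

# the three stride-2 convolutions visible in the hunk above
stem = nn.Sequential(
    nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
    nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
    nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
)

x = stem(torch.randn(1, 3, 224, 224))
print(x.shape)  # torch.Size([1, 128, 28, 28]): 224 -> 112 -> 56 -> 28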
@@ -176,7 +175,7 @@ class LeViT(nn.Module):
         self.mlp_head = nn.Linear(dim, num_classes)

     def forward(self, img):
-        x = self.to_patch_embedding(img)
+        x = self.conv_embedding(img)

         x = self.backbone(x)
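One practical consequence of renaming a registered submodule: its state_dict keys change, so checkpoints saved before this commit will not load directly. A minimal key-remapping sketch ('levit_old.pt' is a hypothetical pre-0.15.2 checkpoint path):

import torch

state = torch.load('levit_old.pt')   # hypothetical checkpoint saved pre-rename
state = {
    k.replace('to_patch_embedding.', 'conv_embedding.'): v
    for k, v in state.items()
}
levit.load_state_dict(state)         # levit: a LeViT instance as built above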