mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-29 23:52:27 +00:00
21 lines
411 B
Python
import torch
|
|
from vit_pytorch import ViT
|
|
|
|
def test_vit():
    """Smoke-test ViT: one forward pass on a random image batch yields logits of shape (batch, num_classes)."""
    # Shared values so the model config, the input tensor, and the expected
    # output shape cannot silently drift apart.
    batch_size = 1
    image_size = 256
    num_classes = 1000

    v = ViT(
        image_size = image_size,
        patch_size = 32,
        num_classes = num_classes,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    # assumes ViT expects (batch, channels, height, width) with 3 channels,
    # matching the input the original test used — confirm against ViT's forward
    img = torch.randn(batch_size, 3, image_size, image_size)

    preds = v(img)

    expected = (batch_size, num_classes)
    # An assert message is only shown on failure, so it describes the mismatch.
    assert preds.shape == expected, f'expected logits of shape {expected}, got {tuple(preds.shape)}'
|