Files
vit-pytorch/tests/test.py
2021-12-22 09:13:31 -08:00

21 lines
407 B
Python

import torch
from vit_pytorch import ViT
def test():
    """Smoke-test a ViT forward pass.

    Builds a ``ViT`` classifier (256x256 images, 32x32 patches, 1000 classes),
    runs one random image through it, and checks that the output holds one
    row of ``num_classes`` logits per input image.
    """
    # Shared by model construction, input creation and the final assertion,
    # so the expected shape cannot drift from the configuration.
    batch_size = 1
    image_size = 256
    num_classes = 1000

    v = ViT(
        image_size = image_size,
        patch_size = 32,
        num_classes = num_classes,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    # assumes ViT expects (batch, channels, height, width) input with 3
    # channels, as in the original test — confirm against vit_pytorch docs
    img = torch.randn(batch_size, 3, image_size, image_size)

    # Only the output shape is checked, so skip autograd bookkeeping to keep
    # the test's memory footprint down.
    with torch.no_grad():
        preds = v(img)

    # An assert message is shown only on failure, so it must describe the
    # failure rather than the success case.
    expected_shape = (batch_size, num_classes)
    assert preds.shape == expected_shape, (
        f'expected logits of shape {expected_shape}, got {tuple(preds.shape)}'
    )