mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
readme
This commit is contained in:
22
README.md
22
README.md
@@ -150,15 +150,17 @@ import torch
|
||||
from vit_pytorch import ViT
|
||||
from vit_pytorch.mpp import MPP
|
||||
|
||||
model = ViT(image_size=256,
|
||||
patch_size=32,
|
||||
num_classes=1000,
|
||||
dim=1024,
|
||||
depth=6,
|
||||
heads=8,
|
||||
mlp_dim=2048,
|
||||
dropout=0.1,
|
||||
emb_dropout=0.1)
|
||||
model = ViT(
|
||||
image_size=256,
|
||||
patch_size=32,
|
||||
num_classes=1000,
|
||||
dim=1024,
|
||||
depth=6,
|
||||
heads=8,
|
||||
mlp_dim=2048,
|
||||
dropout=0.1,
|
||||
emb_dropout=0.1
|
||||
)
|
||||
|
||||
mpp_trainer = MPP(
|
||||
transformer=model,
|
||||
@@ -171,11 +173,9 @@ mpp_trainer = MPP(
|
||||
|
||||
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
|
||||
|
||||
|
||||
def sample_unlabelled_images():
|
||||
return torch.randn(20, 3, 256, 256)
|
||||
|
||||
|
||||
for _ in range(100):
|
||||
images = sample_unlabelled_images()
|
||||
loss = mpp_trainer(images)
|
||||
|
||||
Reference in New Issue
Block a user