From a6cbda37b9579a1bb135a736829638bfea0debe1 Mon Sep 17 00:00:00 2001
From: Zack Ankner
Date: Mon, 8 Mar 2021 09:34:55 -0500
Subject: [PATCH] added to readme

---
 README.md | 43 +++++++++++++++++++++++++++++++++++++++++++
 test.py   | 42 ------------------------------------------
 2 files changed, 43 insertions(+), 42 deletions(-)
 delete mode 100644 test.py

diff --git a/README.md b/README.md
index a1ac542..599ee9b 100644
--- a/README.md
+++ b/README.md
@@ -121,6 +121,49 @@ type(v) #
 
 ### Self Supervised Training
 
+You can train with the masked patch prediction task presented in the original paper, using the following code.
+
+```python
+import torch
+from vit_pytorch import ViT, MPP
+
+model = ViT(image_size=256,
+            patch_size=32,
+            num_classes=1000,
+            dim=1024,
+            depth=6,
+            heads=8,
+            mlp_dim=2048,
+            dropout=0.1,
+            emb_dropout=0.1)
+
+mpp_trainer = MPP(
+    transformer=model,
+    patch_size=32,
+    dim=1024,
+    mask_prob=0.15,          # probability that a patch is selected for the masked prediction task
+    random_patch_prob=0.30,  # probability that a selected patch is replaced with a random patch
+    replace_prob=0.50,       # probability that a selected patch is replaced with the mask token
+)
+
+opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
+
+
+def sample_unlabelled_images():
+    return torch.randn(20, 3, 256, 256)
+
+
+for _ in range(100):
+    images = sample_unlabelled_images()
+    loss = mpp_trainer(images)
+    opt.zero_grad()
+    loss.backward()
+    opt.step()
+
+# save your improved network
+torch.save(model.state_dict(), './pretrained-net.pt')
+```
+
 You can train this with a near SOTA self-supervised learning technique, BYOL, with the following code.
 
 ```python

diff --git a/test.py b/test.py
deleted file mode 100644
index 172037e..0000000
--- a/test.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import torch
-from vit_pytorch import MPP, ViT
-# from vit_pytorch import ViT, MaskedPredictionLoss
-
-v = ViT(
-    image_size = 256,
-    patch_size = 32,
-    num_classes = 3,
-    dim = 1024,
-    depth = 6,
-    heads = 16,
-    mlp_dim = 2048,
-    dropout = 0.1,
-    emb_dropout = 0.1
-)
-
-# l = MaskedPredictionLoss(patch_size=32, img_size=256)
-
-img = torch.randn(2, 3, 256, 256)
-# mask = [1,2,3,4] # optional mask, designating which patch to attend to
-
-# preds = v(img) # (1, 1000)
-# loss = l(preds, img, mask)
-
-# print(preds.shape)
-
-
-trainer = MPP(
-    transformer = v,
-    patch_size = 32,
-    dim = 1024,
-    mask_prob = 0.15,       # masking probability for masked language modeling
-    random_patch_prob=0.30,
-    replace_prob = 0.50,    # ~10% probability that token will not be masked, but included in loss, as detailed in the epaper
-)
-
-# data = torch.rand((2, 3, 10, 10))
-
-loss = trainer(img)
-
-print(loss)
-
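
Once MPP pretraining finishes, the checkpoint written by `torch.save` above can be loaded back into a `ViT` with the same configuration for supervised fine-tuning. Below is a minimal sketch of that step; it is not part of the patch, and `sample_labelled_batch` is a hypothetical stand-in for a real labelled data loader:

```python
import torch
import torch.nn.functional as F
from vit_pytorch import ViT

# rebuild the exact architecture that was pretrained with MPP above
model = ViT(image_size=256,
            patch_size=32,
            num_classes=1000,
            dim=1024,
            depth=6,
            heads=8,
            mlp_dim=2048,
            dropout=0.1,
            emb_dropout=0.1)

# load the weights saved at the end of MPP pretraining
model.load_state_dict(torch.load('./pretrained-net.pt'))

opt = torch.optim.Adam(model.parameters(), lr=3e-4)


def sample_labelled_batch():
    # hypothetical stand-in for a real labelled data loader
    return torch.randn(8, 3, 256, 256), torch.randint(0, 1000, (8,))


for _ in range(100):
    images, labels = sample_labelled_batch()
    logits = model(images)  # (8, 1000)
    loss = F.cross_entropy(logits, labels)
    opt.zero_grad()
    loss.backward()
    opt.step()
```

Note that fine-tuning reuses only the transformer weights: the prediction head defined inside the MPP wrapper is discarded, which is why the patch saves `model.state_dict()` rather than `mpp_trainer.state_dict()`.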