added to readme

This commit is contained in:
Zack Ankner
2021-03-08 09:34:55 -05:00
parent 73de1e8a73
commit a6cbda37b9
2 changed files with 43 additions and 42 deletions


@@ -121,6 +121,49 @@ type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
### Self Supervised Training
You can train the ViT with the masked patch prediction task presented in the original paper, using the following code.
```python
import torch
from vit_pytorch import ViT, MPP

model = ViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1
)

mpp_trainer = MPP(
    transformer=model,
    patch_size=32,
    dim=1024,
    mask_prob=0.15,          # probability of using token in masked prediction task
    random_patch_prob=0.30,  # probability of randomly replacing a token being used for mpp
    replace_prob=0.50,       # probability of replacing a token being used for mpp with the mask token
)

opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = mpp_trainer(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')
```
You can also train the model with <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, a near-SOTA self-supervised learning technique, using the following code.
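Below is a minimal sketch of that pairing. It assumes the `BYOL` wrapper from byol-pytorch (its `image_size` and `hidden_layer` arguments and its `update_moving_average()` method) and assumes `'to_latent'` is the name of the ViT layer to hook; adjust to match the actual APIs.

```python
import torch
from vit_pytorch import ViT
from byol_pytorch import BYOL

model = ViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048
)

learner = BYOL(
    model,
    image_size=256,
    hidden_layer='to_latent'   # assumed name of the ViT layer whose output BYOL projects
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)               # BYOL computes the self-supervised loss on two augmented views
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average()      # update the exponential moving average of the target encoder

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')
```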

test.py

@@ -1,42 +0,0 @@
import torch
from vit_pytorch import MPP, ViT
# from vit_pytorch import ViT, MaskedPredictionLoss
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 3,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
# l = MaskedPredictionLoss(patch_size=32, img_size=256)
img = torch.randn(2, 3, 256, 256)
# mask = [1,2,3,4] # optional mask, designating which patch to attend to
# preds = v(img) # (1, 1000)
# loss = l(preds, img, mask)
# print(preds.shape)
trainer = MPP(
    transformer = v,
    patch_size = 32,
    dim = 1024,
    mask_prob = 0.15,          # masking probability for masked patch prediction
    random_patch_prob = 0.30,  # probability of randomly replacing a token being used for mpp
    replace_prob = 0.50,       # ~10% probability that token will not be masked, but included in loss, as detailed in the paper
)
# data = torch.rand((2, 3, 10, 10))
loss = trainer(img)
print(loss)