mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
added to readme
This commit is contained in:
43
README.md
43
README.md
@@ -121,6 +121,49 @@ type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
|
||||
|
||||
### Self Supervised Training
|
||||
|
||||
You can train using the original masked patch prediction task presented in the paper, with the following code.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import ViT, MPP
|
||||
|
||||
model = ViT(image_size=256,
|
||||
patch_size=32,
|
||||
num_classes=1000,
|
||||
dim=1024,
|
||||
depth=6,
|
||||
heads=8,
|
||||
mlp_dim=2048,
|
||||
dropout=0.1,
|
||||
emb_dropout=0.1)
|
||||
|
||||
mpp_trainer = MPP(
|
||||
transformer=model,
|
||||
patch_size=32,
|
||||
dim=1024,
|
||||
mask_prob=0.15, # probability of using token in masked prediction task
|
||||
random_patch_prob=0.30, # probability of randomly replacing a token being used for mpp
|
||||
replace_prob=0.50, # probability of replacing a token being used for mpp with the mask token
|
||||
)
|
||||
|
||||
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
|
||||
|
||||
|
||||
def sample_unlabelled_images():
|
||||
return torch.randn(20, 3, 256, 256)
|
||||
|
||||
|
||||
for _ in range(100):
|
||||
images = sample_unlabelled_images()
|
||||
loss = mpp_trainer(images)
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
opt.step()
|
||||
|
||||
# save your improved network
|
||||
torch.save(model.state_dict(), './pretrained-net.pt')
|
||||
```
|
||||
|
||||
You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
|
||||
|
||||
(1)
|
||||
|
||||
42
test.py
42
test.py
@@ -1,42 +0,0 @@
|
||||
import torch
|
||||
from vit_pytorch import MPP, ViT
|
||||
# from vit_pytorch import ViT, MaskedPredictionLoss
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 3,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
# l = MaskedPredictionLoss(patch_size=32, img_size=256)
|
||||
|
||||
img = torch.randn(2, 3, 256, 256)
|
||||
# mask = [1,2,3,4] # optional mask, designating which patch to attend to
|
||||
|
||||
# preds = v(img) # (1, 1000)
|
||||
# loss = l(preds, img, mask)
|
||||
|
||||
# print(preds.shape)
|
||||
|
||||
|
||||
trainer = MPP(
|
||||
transformer = v,
|
||||
patch_size = 32,
|
||||
dim = 1024,
|
||||
mask_prob = 0.15, # masking probability for masked language modeling
|
||||
random_patch_prob=0.30,
|
||||
replace_prob = 0.50, # ~10% probability that token will not be masked, but included in loss, as detailed in the epaper
|
||||
)
|
||||
|
||||
# data = torch.rand((2, 3, 10, 10))
|
||||
|
||||
loss = trainer(img)
|
||||
|
||||
print(loss)
|
||||
|
||||
Reference in New Issue
Block a user