This commit is contained in:
Phil Wang
2021-11-12 20:19:38 -08:00
parent e8f6d72033
commit 5b16e8f809

View File

@@ -514,10 +514,9 @@ v = ViT(
 mae = MAE(
     encoder = v,
-    masking_ratio = 0.75,
-    decoder_dim = 1024,
-    decoder_depth = 6,
-    decoder_heads = 8
+    masking_ratio = 0.75,   # the paper recommended 75% masked patches
+    decoder_dim = 512,      # paper showed good results with just 512
+    decoder_depth = 6       # anywhere from 1 to 8
 )

 images = torch.randn(8, 3, 256, 256)
@@ -527,6 +526,9 @@ loss.backward()
 # that's all!
 # do the above in a for loop many times with a lot of images and your vision transformer will learn
+
+# save your improved vision transformer
+torch.save(v.state_dict(), './trained-vit.pt')
 ```

 ## Masked Patch Prediction