mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
readme
This commit is contained in:
10
README.md
10
README.md
@@ -514,10 +514,9 @@ v = ViT(
|
|||||||
|
|
||||||
mae = MAE(
|
mae = MAE(
|
||||||
encoder = v,
|
encoder = v,
|
||||||
masking_ratio = 0.75,
|
masking_ratio = 0.75, # the paper recommended 75% masked patches
|
||||||
decoder_dim = 1024,
|
decoder_dim = 512, # paper showed good results with just 512
|
||||||
decoder_depth = 6,
|
decoder_depth = 6 # anywhere from 1 to 8
|
||||||
decoder_heads = 8
|
|
||||||
)
|
)
|
||||||
|
|
||||||
images = torch.randn(8, 3, 256, 256)
|
images = torch.randn(8, 3, 256, 256)
|
||||||
@@ -527,6 +526,9 @@ loss.backward()
|
|||||||
|
|
||||||
# that's all!
|
# that's all!
|
||||||
# do the above in a for loop many times with a lot of images and your vision transformer will learn
|
# do the above in a for loop many times with a lot of images and your vision transformer will learn
|
||||||
|
|
||||||
|
# save your improved vision transformer
|
||||||
|
torch.save(v.state_dict(), './trained-vit.pt')
|
||||||
```
|
```
|
||||||
|
|
||||||
## Masked Patch Prediction
|
## Masked Patch Prediction
|
||||||
|
|||||||
Reference in New Issue
Block a user