mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
readme
This commit is contained in:
10
README.md
10
README.md
@@ -514,10 +514,9 @@ v = ViT(
|
||||
|
||||
mae = MAE(
|
||||
encoder = v,
|
||||
masking_ratio = 0.75,
|
||||
decoder_dim = 1024,
|
||||
decoder_depth = 6,
|
||||
decoder_heads = 8
|
||||
masking_ratio = 0.75, # the paper recommended 75% masked patches
|
||||
decoder_dim = 512, # paper showed good results with just 512
|
||||
decoder_depth = 6 # anywhere from 1 to 8
|
||||
)
|
||||
|
||||
images = torch.randn(8, 3, 256, 256)
|
||||
@@ -527,6 +526,9 @@ loss.backward()
|
||||
|
||||
# that's all!
|
||||
# do the above in a for loop many times with a lot of images and your vision transformer will learn
|
||||
|
||||
# save your improved vision transformer
|
||||
torch.save(v.state_dict(), './trained-vit.pt')
|
||||
```
|
||||
|
||||
## Masked Patch Prediction
|
||||
|
||||
Reference in New Issue
Block a user