mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
cleanup and release 0.8.0
This commit is contained in:
17
README.md
17
README.md
@@ -141,15 +141,14 @@ img = torch.randn(1, 3, 224, 224)
|
||||
v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Research Ideas
|
||||
## Masked Patch Prediction
|
||||
|
||||
### Self Supervised Training
|
||||
|
||||
You can train using the original masked patch prediction task presented in the paper, with the following code.
|
||||
Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import ViT, MPP
|
||||
from vit_pytorch import ViT
|
||||
from vit_pytorch.mpp import MPP
|
||||
|
||||
model = ViT(image_size=256,
|
||||
patch_size=32,
|
||||
@@ -165,9 +164,9 @@ mpp_trainer = MPP(
|
||||
transformer=model,
|
||||
patch_size=32,
|
||||
dim=1024,
|
||||
mask_prob=0.15, # probability of using token in masked prediction task
|
||||
mask_prob=0.15, # probability of using token in masked prediction task
|
||||
random_patch_prob=0.30, # probability of randomly replacing a token being used for mpp
|
||||
replace_prob=0.50, # probability of replacing a token being used for mpp with the mask token
|
||||
replace_prob=0.50, # probability of replacing a token being used for mpp with the mask token
|
||||
)
|
||||
|
||||
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
|
||||
@@ -188,6 +187,10 @@ for _ in range(100):
|
||||
torch.save(model.state_dict(), './pretrained-net.pt')
|
||||
```
|
||||
|
||||
## Research Ideas
|
||||
|
||||
### Self Supervised Training
|
||||
|
||||
You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
|
||||
|
||||
(1)
|
||||
|
||||
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.7.6',
|
||||
version = '0.8.0',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
from vit_pytorch.vit_pytorch import ViT
|
||||
from vit_pytorch.mpp_pytorch import MPP
|
||||
|
||||
@@ -106,6 +106,7 @@ class MPP(nn.Module):
|
||||
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
|
||||
|
||||
def forward(self, input, **kwargs):
|
||||
transformer = self.transformer
|
||||
# clone original image for loss
|
||||
img = input.clone().detach()
|
||||
|
||||
@@ -144,19 +145,19 @@ class MPP(nn.Module):
|
||||
masked_input[bool_mask_replace] = self.mask_token
|
||||
|
||||
# linear embedding of patches
|
||||
masked_input = self.transformer.patch_to_embedding(masked_input)
|
||||
masked_input = transformer.to_patch_embedding[-1](masked_input)
|
||||
|
||||
# add cls token to input sequence
|
||||
b, n, _ = masked_input.shape
|
||||
cls_tokens = repeat(self.transformer.cls_token, '() n d -> b n d', b=b)
|
||||
cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b)
|
||||
masked_input = torch.cat((cls_tokens, masked_input), dim=1)
|
||||
|
||||
# add positional embeddings to input
|
||||
masked_input += self.transformer.pos_embedding[:, :(n + 1)]
|
||||
masked_input = self.transformer.dropout(masked_input)
|
||||
masked_input += transformer.pos_embedding[:, :(n + 1)]
|
||||
masked_input = transformer.dropout(masked_input)
|
||||
|
||||
# get generator output and get mpp loss
|
||||
masked_input = self.transformer.transformer(masked_input, **kwargs)
|
||||
masked_input = transformer.transformer(masked_input, **kwargs)
|
||||
cls_logits = self.to_bits(masked_input)
|
||||
logits = cls_logits[:, 1:, :]
|
||||
|
||||
Reference in New Issue
Block a user