cleanup and release 0.8.0

Phil Wang
2021-03-08 07:28:31 -08:00
parent 0e63766e54
commit 173e07e02e
4 changed files with 17 additions and 14 deletions

README.md

@@ -141,15 +141,14 @@ img = torch.randn(1, 3, 224, 224)
 v(img) # (1, 1000)
 ```
 
-## Research Ideas
-
-### Self Supervised Training
+## Masked Patch Prediction
 
-You can train using the original masked patch prediction task presented in the paper, with the following code.
+Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
 
 ```python
 import torch
-from vit_pytorch import ViT, MPP
+from vit_pytorch import ViT
+from vit_pytorch.mpp import MPP
 
 model = ViT(image_size=256,
     patch_size=32,
@@ -165,9 +164,9 @@ mpp_trainer = MPP(
     transformer=model,
     patch_size=32,
     dim=1024,
-    mask_prob=0.15,  # probability of using token in masked prediction task
+    mask_prob=0.15,          # probability of using token in masked prediction task
     random_patch_prob=0.30,  # probability of randomly replacing a token being used for mpp
-    replace_prob=0.50,  # probability of replacing a token being used for mpp with the mask token
+    replace_prob=0.50,       # probability of replacing a token being used for mpp with the mask token
 )
 
 opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
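For orientation, the three probabilities in the hunk above compose as the inline comments describe: mask_prob selects which patches enter the masked prediction task, then replace_prob and random_patch_prob decide how each selected patch is corrupted. A back-of-envelope sketch, assuming a 256x256 image with 32x32 patches; the counts are illustrative expectations, not values from this commit:

```python
# Expected patch counts under the corruption scheme described by the
# comments above (BERT-style composition is an assumption, not verified here).
num_patches = (256 // 32) ** 2               # 64 patches per image
in_task     = 0.15 * num_patches             # ~10 patches enter the MPP objective
to_mask     = 0.50 * in_task                 # ~5 replaced by the learned mask token
to_random   = 0.30 * in_task                 # ~3 swapped for a random patch
unchanged   = in_task - to_mask - to_random  # ~2 left intact but still predicted
```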
@@ -188,6 +187,10 @@ for _ in range(100):
 torch.save(model.state_dict(), './pretrained-net.pt')
 ```
 
+## Research Ideas
+### Self Supervised Training
+
+You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
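The BYOL example this restored section points to is unchanged by the commit, so the diff does not show it. For context, a minimal sketch of that usage; the BYOL(net, image_size, hidden_layer) signature, the 'to_latent' layer name, and the training loop below are assumptions drawn from byol-pytorch's README, not part of this commit:

```python
import torch
from vit_pytorch import ViT
from byol_pytorch import BYOL  # assumed wrapper API from lucidrains/byol-pytorch

model = ViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048,
)

# wrap the ViT; 'to_latent' is an assumed name for the layer to pool features from
learner = BYOL(model, image_size=256, hidden_layer='to_latent')
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

for _ in range(100):
    images = torch.randn(8, 3, 256, 256)  # stand-in for a real dataloader
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average()  # step the target network's moving average

torch.save(model.state_dict(), './pretrained-byol-net.pt')
```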

setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.7.6',
+  version = '0.8.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

vit_pytorch/__init__.py

@@ -1,2 +1 @@
 from vit_pytorch.vit_pytorch import ViT
-from vit_pytorch.mpp_pytorch import MPP
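Read together with the README hunk above, this removal means MPP is no longer re-exported from the package root. A before/after sketch of the import path:

```python
# before this commit (0.7.x):
# from vit_pytorch import ViT, MPP

# from 0.8.0 onward, per the updated README:
from vit_pytorch import ViT
from vit_pytorch.mpp import MPP
```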

vit_pytorch/mpp.py

@@ -106,6 +106,7 @@ class MPP(nn.Module):
         self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
 
     def forward(self, input, **kwargs):
+        transformer = self.transformer
 
         # clone original image for loss
         img = input.clone().detach()
@@ -144,19 +145,19 @@ class MPP(nn.Module):
         masked_input[bool_mask_replace] = self.mask_token
 
         # linear embedding of patches
-        masked_input = self.transformer.patch_to_embedding(masked_input)
+        masked_input = transformer.to_patch_embedding[-1](masked_input)
 
         # add cls token to input sequence
         b, n, _ = masked_input.shape
-        cls_tokens = repeat(self.transformer.cls_token, '() n d -> b n d', b=b)
+        cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b)
         masked_input = torch.cat((cls_tokens, masked_input), dim=1)
 
         # add positional embeddings to input
-        masked_input += self.transformer.pos_embedding[:, :(n + 1)]
-        masked_input = self.transformer.dropout(masked_input)
+        masked_input += transformer.pos_embedding[:, :(n + 1)]
+        masked_input = transformer.dropout(masked_input)
 
         # get generator output and get mpp loss
-        masked_input = self.transformer.transformer(masked_input, **kwargs)
+        masked_input = transformer.transformer(masked_input, **kwargs)
         cls_logits = self.to_bits(masked_input)
         logits = cls_logits[:, 1:, :]
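The switch from patch_to_embedding to to_patch_embedding[-1] suggests the ViT's patch embedding is now an nn.Sequential whose first stage rearranges pixels into flattened patches and whose last stage is the linear projection; MPP builds masked_input from patches that are already flattened, so only the final stage applies. A sketch of the assumed structure (the actual definition lives in vit_pytorch/vit_pytorch.py and is not shown in this diff):

```python
import torch.nn as nn
from einops.layers.torch import Rearrange

# assumed layout of ViT.to_patch_embedding after the refactor this hunk tracks,
# with the README's patch_size=32, dim=1024 and 3 input channels
to_patch_embedding = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32),
    nn.Linear(32 * 32 * 3, 1024),  # patch_dim -> dim; to_patch_embedding[-1] is this layer
)
```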