Compare commits

...

11 Commits
0.7.4 ... 0.8.0

4 changed files with 225 additions and 4 deletions

View File

@@ -141,6 +141,52 @@ img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)
```
## Masked Patch Prediction
Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
```python
import torch
from vit_pytorch import ViT
from vit_pytorch.mpp import MPP
model = ViT(image_size=256,
patch_size=32,
num_classes=1000,
dim=1024,
depth=6,
heads=8,
mlp_dim=2048,
dropout=0.1,
emb_dropout=0.1)
mpp_trainer = MPP(
transformer=model,
patch_size=32,
dim=1024,
mask_prob=0.15, # probability of using token in masked prediction task
random_patch_prob=0.30, # probability of randomly replacing a token being used for mpp
replace_prob=0.50, # probability of replacing a token being used for mpp with the mask token
)
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
def sample_unlabelled_images():
return torch.randn(20, 3, 256, 256)
for _ in range(100):
images = sample_unlabelled_images()
loss = mpp_trainer(images)
opt.zero_grad()
loss.backward()
opt.step()
# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')
```
## Research Ideas
### Self Supervised Training

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.7.4',
version = '0.8.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

166
vit_pytorch/mpp.py Normal file
View File

@@ -0,0 +1,166 @@
import math
from functools import reduce
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat
# helpers
def prob_mask_like(t, prob):
batch, seq_length, _ = t.shape
return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob
def get_mask_subset_with_prob(patched_input, prob):
batch, seq_len, _, device = *patched_input.shape, patched_input.device
max_masked = math.ceil(prob * seq_len)
rand = torch.rand((batch, seq_len), device=device)
_, sampled_indices = rand.topk(max_masked, dim=-1)
new_mask = torch.zeros((batch, seq_len), device=device)
new_mask.scatter_(1, sampled_indices, 1)
return new_mask.bool()
# mpp loss
class MPPLoss(nn.Module):
def __init__(self, patch_size, channels, output_channel_bits,
max_pixel_val):
super(MPPLoss, self).__init__()
self.patch_size = patch_size
self.channels = channels
self.output_channel_bits = output_channel_bits
self.max_pixel_val = max_pixel_val
def forward(self, predicted_patches, target, mask):
# reshape target to patches
p = self.patch_size
target = rearrange(target,
"b c (h p1) (w p2) -> b (h w) c (p1 p2) ",
p1=p,
p2=p)
avg_target = target.mean(dim=3)
bin_size = self.max_pixel_val / self.output_channel_bits
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size)
discretized_target = torch.bucketize(avg_target, channel_bins)
discretized_target = F.one_hot(discretized_target,
self.output_channel_bits)
c, bi = self.channels, self.output_channel_bits
discretized_target = rearrange(discretized_target,
"b n c bi -> b n (c bi)",
c=c,
bi=bi)
bin_mask = 2**torch.arange(c * bi - 1, -1,
-1).to(discretized_target.device,
discretized_target.dtype)
target_label = torch.sum(bin_mask * discretized_target, -1)
predicted_patches = predicted_patches[mask]
target_label = target_label[mask]
loss = F.cross_entropy(predicted_patches, target_label)
return loss
# main class
class MPP(nn.Module):
def __init__(self,
transformer,
patch_size,
dim,
output_channel_bits=3,
channels=3,
max_pixel_val=1.0,
mask_prob=0.15,
replace_prob=0.5,
random_patch_prob=0.5):
super().__init__()
self.transformer = transformer
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
max_pixel_val)
# output transformation
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
# vit related dimensions
self.patch_size = patch_size
# mpp related probabilities
self.mask_prob = mask_prob
self.replace_prob = replace_prob
self.random_patch_prob = random_patch_prob
# token ids
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
def forward(self, input, **kwargs):
transformer = self.transformer
# clone original image for loss
img = input.clone().detach()
# reshape raw image to patches
p = self.patch_size
input = rearrange(input,
'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
p1=p,
p2=p)
mask = get_mask_subset_with_prob(input, self.mask_prob)
# mask input with mask patches with probability of `replace_prob` (keep patches the same with probability 1 - replace_prob)
masked_input = input.clone().detach()
# if random token probability > 0 for mpp
if self.random_patch_prob > 0:
random_patch_sampling_prob = self.random_patch_prob / (
1 - self.replace_prob)
random_patch_prob = prob_mask_like(input,
random_patch_sampling_prob)
bool_random_patch_prob = mask * random_patch_prob == True
random_patches = torch.randint(0,
input.shape[1],
(input.shape[0], input.shape[1]),
device=input.device)
randomized_input = masked_input[
torch.arange(masked_input.shape[0]).unsqueeze(-1),
random_patches]
masked_input[bool_random_patch_prob] = randomized_input[
bool_random_patch_prob]
# [mask] input
replace_prob = prob_mask_like(input, self.replace_prob)
bool_mask_replace = (mask * replace_prob) == True
masked_input[bool_mask_replace] = self.mask_token
# linear embedding of patches
masked_input = transformer.to_patch_embedding[-1](masked_input)
# add cls token to input sequence
b, n, _ = masked_input.shape
cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b)
masked_input = torch.cat((cls_tokens, masked_input), dim=1)
# add positional embeddings to input
masked_input += transformer.pos_embedding[:, :(n + 1)]
masked_input = transformer.dropout(masked_input)
# get generator output and get mpp loss
masked_input = transformer.transformer(masked_input, **kwargs)
cls_logits = self.to_bits(masked_input)
logits = cls_logits[:, 1:, :]
mpp_loss = self.loss(logits, img, mask)
return mpp_loss

View File

@@ -7,11 +7,16 @@ from vit_pytorch.vit_pytorch import Transformer
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# classes
# helpers
def exists(val):
return val is not None
def conv_output_size(image_size, kernel_size, stride, padding):
return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
# classes
class RearrangeImage(nn.Module):
def forward(self, x):
return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1])))
@@ -20,7 +25,7 @@ class RearrangeImage(nn.Module):
class T2TViT(nn.Module):
def __init__(
self, *, image_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
self, *, image_size, num_classes, dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))):
super().__init__()
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
@@ -47,7 +52,11 @@ class T2TViT(nn.Module):
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
if not exists(transformer):
assert all([exists(depth), exists(heads), exists(mlp_dim)]), 'depth, heads, and mlp_dim must be supplied'
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
else:
self.transformer = transformer
self.pool = pool
self.to_latent = nn.Identity()