mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-05-17 06:11:34 +00:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
173e07e02e | ||
|
|
0e63766e54 | ||
|
|
a6cbda37b9 | ||
|
|
73de1e8a73 | ||
|
|
1698b7bef8 | ||
|
|
fc14561de7 | ||
|
|
be5d560821 | ||
|
|
77703ae1fc | ||
|
|
a0a4fa5e7d | ||
|
|
174e71cf53 | ||
|
|
e14bd14a8f |
46
README.md
46
README.md
@@ -141,6 +141,52 @@ img = torch.randn(1, 3, 224, 224)
|
||||
v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Masked Patch Prediction
|
||||
|
||||
Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import ViT
|
||||
from vit_pytorch.mpp import MPP
|
||||
|
||||
model = ViT(image_size=256,
|
||||
patch_size=32,
|
||||
num_classes=1000,
|
||||
dim=1024,
|
||||
depth=6,
|
||||
heads=8,
|
||||
mlp_dim=2048,
|
||||
dropout=0.1,
|
||||
emb_dropout=0.1)
|
||||
|
||||
mpp_trainer = MPP(
|
||||
transformer=model,
|
||||
patch_size=32,
|
||||
dim=1024,
|
||||
mask_prob=0.15, # probability of using token in masked prediction task
|
||||
random_patch_prob=0.30, # probability of randomly replacing a token being used for mpp
|
||||
replace_prob=0.50, # probability of replacing a token being used for mpp with the mask token
|
||||
)
|
||||
|
||||
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
|
||||
|
||||
|
||||
def sample_unlabelled_images():
|
||||
return torch.randn(20, 3, 256, 256)
|
||||
|
||||
|
||||
for _ in range(100):
|
||||
images = sample_unlabelled_images()
|
||||
loss = mpp_trainer(images)
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
opt.step()
|
||||
|
||||
# save your improved network
|
||||
torch.save(model.state_dict(), './pretrained-net.pt')
|
||||
```
|
||||
|
||||
## Research Ideas
|
||||
|
||||
### Self Supervised Training
|
||||
|
||||
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.7.4',
|
||||
version = '0.8.0',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
166
vit_pytorch/mpp.py
Normal file
166
vit_pytorch/mpp.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import math
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def prob_mask_like(t, prob):
|
||||
batch, seq_length, _ = t.shape
|
||||
return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob
|
||||
|
||||
|
||||
def get_mask_subset_with_prob(patched_input, prob):
|
||||
batch, seq_len, _, device = *patched_input.shape, patched_input.device
|
||||
max_masked = math.ceil(prob * seq_len)
|
||||
|
||||
rand = torch.rand((batch, seq_len), device=device)
|
||||
_, sampled_indices = rand.topk(max_masked, dim=-1)
|
||||
|
||||
new_mask = torch.zeros((batch, seq_len), device=device)
|
||||
new_mask.scatter_(1, sampled_indices, 1)
|
||||
return new_mask.bool()
|
||||
|
||||
|
||||
# mpp loss
|
||||
|
||||
|
||||
class MPPLoss(nn.Module):
|
||||
def __init__(self, patch_size, channels, output_channel_bits,
|
||||
max_pixel_val):
|
||||
super(MPPLoss, self).__init__()
|
||||
self.patch_size = patch_size
|
||||
self.channels = channels
|
||||
self.output_channel_bits = output_channel_bits
|
||||
self.max_pixel_val = max_pixel_val
|
||||
|
||||
def forward(self, predicted_patches, target, mask):
|
||||
# reshape target to patches
|
||||
p = self.patch_size
|
||||
target = rearrange(target,
|
||||
"b c (h p1) (w p2) -> b (h w) c (p1 p2) ",
|
||||
p1=p,
|
||||
p2=p)
|
||||
|
||||
avg_target = target.mean(dim=3)
|
||||
|
||||
bin_size = self.max_pixel_val / self.output_channel_bits
|
||||
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size)
|
||||
discretized_target = torch.bucketize(avg_target, channel_bins)
|
||||
discretized_target = F.one_hot(discretized_target,
|
||||
self.output_channel_bits)
|
||||
c, bi = self.channels, self.output_channel_bits
|
||||
discretized_target = rearrange(discretized_target,
|
||||
"b n c bi -> b n (c bi)",
|
||||
c=c,
|
||||
bi=bi)
|
||||
|
||||
bin_mask = 2**torch.arange(c * bi - 1, -1,
|
||||
-1).to(discretized_target.device,
|
||||
discretized_target.dtype)
|
||||
target_label = torch.sum(bin_mask * discretized_target, -1)
|
||||
|
||||
predicted_patches = predicted_patches[mask]
|
||||
target_label = target_label[mask]
|
||||
loss = F.cross_entropy(predicted_patches, target_label)
|
||||
return loss
|
||||
|
||||
|
||||
# main class
|
||||
|
||||
|
||||
class MPP(nn.Module):
|
||||
def __init__(self,
|
||||
transformer,
|
||||
patch_size,
|
||||
dim,
|
||||
output_channel_bits=3,
|
||||
channels=3,
|
||||
max_pixel_val=1.0,
|
||||
mask_prob=0.15,
|
||||
replace_prob=0.5,
|
||||
random_patch_prob=0.5):
|
||||
super().__init__()
|
||||
|
||||
self.transformer = transformer
|
||||
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
|
||||
max_pixel_val)
|
||||
|
||||
# output transformation
|
||||
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
|
||||
|
||||
# vit related dimensions
|
||||
self.patch_size = patch_size
|
||||
|
||||
# mpp related probabilities
|
||||
self.mask_prob = mask_prob
|
||||
self.replace_prob = replace_prob
|
||||
self.random_patch_prob = random_patch_prob
|
||||
|
||||
# token ids
|
||||
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
|
||||
|
||||
def forward(self, input, **kwargs):
|
||||
transformer = self.transformer
|
||||
# clone original image for loss
|
||||
img = input.clone().detach()
|
||||
|
||||
# reshape raw image to patches
|
||||
p = self.patch_size
|
||||
input = rearrange(input,
|
||||
'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
|
||||
p1=p,
|
||||
p2=p)
|
||||
|
||||
mask = get_mask_subset_with_prob(input, self.mask_prob)
|
||||
|
||||
# mask input with mask patches with probability of `replace_prob` (keep patches the same with probability 1 - replace_prob)
|
||||
masked_input = input.clone().detach()
|
||||
|
||||
# if random token probability > 0 for mpp
|
||||
if self.random_patch_prob > 0:
|
||||
random_patch_sampling_prob = self.random_patch_prob / (
|
||||
1 - self.replace_prob)
|
||||
random_patch_prob = prob_mask_like(input,
|
||||
random_patch_sampling_prob)
|
||||
bool_random_patch_prob = mask * random_patch_prob == True
|
||||
random_patches = torch.randint(0,
|
||||
input.shape[1],
|
||||
(input.shape[0], input.shape[1]),
|
||||
device=input.device)
|
||||
randomized_input = masked_input[
|
||||
torch.arange(masked_input.shape[0]).unsqueeze(-1),
|
||||
random_patches]
|
||||
masked_input[bool_random_patch_prob] = randomized_input[
|
||||
bool_random_patch_prob]
|
||||
|
||||
# [mask] input
|
||||
replace_prob = prob_mask_like(input, self.replace_prob)
|
||||
bool_mask_replace = (mask * replace_prob) == True
|
||||
masked_input[bool_mask_replace] = self.mask_token
|
||||
|
||||
# linear embedding of patches
|
||||
masked_input = transformer.to_patch_embedding[-1](masked_input)
|
||||
|
||||
# add cls token to input sequence
|
||||
b, n, _ = masked_input.shape
|
||||
cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b)
|
||||
masked_input = torch.cat((cls_tokens, masked_input), dim=1)
|
||||
|
||||
# add positional embeddings to input
|
||||
masked_input += transformer.pos_embedding[:, :(n + 1)]
|
||||
masked_input = transformer.dropout(masked_input)
|
||||
|
||||
# get generator output and get mpp loss
|
||||
masked_input = transformer.transformer(masked_input, **kwargs)
|
||||
cls_logits = self.to_bits(masked_input)
|
||||
logits = cls_logits[:, 1:, :]
|
||||
|
||||
mpp_loss = self.loss(logits, img, mask)
|
||||
|
||||
return mpp_loss
|
||||
@@ -7,11 +7,16 @@ from vit_pytorch.vit_pytorch import Transformer
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# classes
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def conv_output_size(image_size, kernel_size, stride, padding):
|
||||
return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
|
||||
|
||||
# classes
|
||||
|
||||
class RearrangeImage(nn.Module):
|
||||
def forward(self, x):
|
||||
return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1])))
|
||||
@@ -20,7 +25,7 @@ class RearrangeImage(nn.Module):
|
||||
|
||||
class T2TViT(nn.Module):
|
||||
def __init__(
|
||||
self, *, image_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
|
||||
self, *, image_size, num_classes, dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))):
|
||||
super().__init__()
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
@@ -47,7 +52,11 @@ class T2TViT(nn.Module):
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
if not exists(transformer):
|
||||
assert all([exists(depth), exists(heads), exists(mlp_dim)]), 'depth, heads, and mlp_dim must be supplied'
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
else:
|
||||
self.transformer = transformer
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
Reference in New Issue
Block a user