Compare commits

...

16 Commits
0.7.2 ... 0.9.1

Author SHA1 Message Date
Phil Wang
b900850144 add deep vit 2021-03-23 11:57:13 -07:00
Phil Wang
78489045cd readme 2021-03-09 19:23:09 -08:00
Phil Wang
173e07e02e cleanup and release 0.8.0 2021-03-08 07:28:31 -08:00
Phil Wang
0e63766e54 Merge pull request #66 from zankner/masked_patch_pred
Masked Patch Prediction "Suggested in #63" Work in Progress
2021-03-08 07:21:52 -08:00
Zack Ankner
a6cbda37b9 added to readme 2021-03-08 09:34:55 -05:00
Zack Ankner
73de1e8a73 converting bin targets to hard labels 2021-03-07 12:19:30 -05:00
Phil Wang
1698b7bef8 make it so one can plug performer into t2tvit 2021-02-25 20:55:34 -08:00
Phil Wang
6760d554aa no need to do projection to combine attention heads for T2Ts initial one-headed attention layers 2021-02-24 12:23:39 -08:00
Phil Wang
a82894846d add DistillableT2TViT 2021-02-21 19:54:45 -08:00
Phil Wang
3744ac691a remove patch size from T2TViT 2021-02-21 19:15:19 -08:00
Zack Ankner
fc14561de7 made bit boundaries a function of output bits and max pixel val, fixed spelling error and reset vit_pytorch to og file 2021-02-13 18:19:21 -07:00
Zack Ankner
be5d560821 mpp loss is now based on descritized average pixels, vit forward unchanged 2021-02-12 18:30:56 -07:00
Zack Ankner
77703ae1fc moving mpp loss into wrapper 2021-02-10 21:47:49 -07:00
Zack Ankner
a0a4fa5e7d Working implementation of masked patch prediction as a wrapper. Need to clean code up 2021-02-09 22:55:06 -07:00
Zack Ankner
174e71cf53 Wrapper for masked patch prediction. Built handling of input and masking of patches. Need to work on integrating into vit forward call and mpp loss function 2021-02-07 16:49:06 -05:00
Zack Ankner
e14bd14a8f Prelim work on masked patch prediction for self supervision 2021-02-04 22:00:02 -05:00
7 changed files with 320 additions and 34 deletions

View File

@@ -117,6 +117,33 @@ v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```
## Deep ViT
This <a href="https://arxiv.org/abs/2103.11886">paper</a> notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a> paper from NLP.
You can use it as follows
```python
import torch
from vit_pytorch.deepvit import DeepViT
v = DeepViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
```
## Token-to-Token ViT
<img src="./t2t.png" width="400px"></img>
@@ -130,7 +157,6 @@ from vit_pytorch.t2t import T2TViT
v = T2TViT(
dim = 512,
image_size = 224,
patch_size = 16,
depth = 5,
heads = 8,
mlp_dim = 512,
@@ -142,6 +168,52 @@ img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)
```
## Masked Patch Prediction
Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
```python
import torch
from vit_pytorch import ViT
from vit_pytorch.mpp import MPP
model = ViT(
image_size=256,
patch_size=32,
num_classes=1000,
dim=1024,
depth=6,
heads=8,
mlp_dim=2048,
dropout=0.1,
emb_dropout=0.1
)
mpp_trainer = MPP(
transformer=model,
patch_size=32,
dim=1024,
mask_prob=0.15, # probability of using token in masked prediction task
random_patch_prob=0.30, # probability of randomly replacing a token being used for mpp
replace_prob=0.50, # probability of replacing a token being used for mpp with the mask token
)
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
def sample_unlabelled_images():
return torch.randn(20, 3, 256, 256)
for _ in range(100):
images = sample_unlabelled_images()
loss = mpp_trainer(images)
opt.zero_grad()
loss.backward()
opt.step()
# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')
```
## Research Ideas
### Self Supervised Training
@@ -300,12 +372,23 @@ Coming from computer vision and new to transformers? Here are some resources tha
```bibtex
@misc{yuan2021tokenstotoken,
title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
year = {2021},
eprint = {2101.11986},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
year = {2021},
eprint = {2101.11986},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{zhou2021deepvit,
title = {DeepViT: Towards Deeper Vision Transformer},
author = {Daquan Zhou and Bingyi Kang and Xiaojie Jin and Linjie Yang and Xiaochen Lian and Qibin Hou and Jiashi Feng},
year = {2021},
eprint = {2103.11886},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.7.1',
version = '0.9.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -1 +1 @@
from vit_pytorch.vit_pytorch import ViT
from vit_pytorch.vit import ViT

View File

@@ -1,4 +1,3 @@
import math
import torch
from torch import nn, einsum
import torch.nn.functional as F
@@ -42,28 +41,37 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
self.reattn_norm = nn.Sequential(
Rearrange('b h i j -> b i j h'),
nn.LayerNorm(heads),
Rearrange('b i j h -> b h i j')
)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None):
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
# attention
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
mask_value = -torch.finfo(dots.dtype).max
if mask is not None:
mask = F.pad(mask.flatten(1), (1, 0), value = True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')
dots.masked_fill_(~mask, mask_value)
del mask
attn = dots.softmax(dim=-1)
# re-attention
attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
attn = self.reattn_norm(attn)
# aggregate and out
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
@@ -78,13 +86,13 @@ class Transformer(nn.Module):
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x, mask = None):
def forward(self, x):
for attn, ff in self.layers:
x = attn(x, mask = mask)
x = attn(x)
x = ff(x)
return x
class ViT(nn.Module):
class DeepViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
@@ -111,7 +119,7 @@ class ViT(nn.Module):
nn.Linear(dim, num_classes)
)
def forward(self, img, mask = None):
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
@@ -120,7 +128,7 @@ class ViT(nn.Module):
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x, mask)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

View File

@@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from vit_pytorch.vit_pytorch import ViT
from vit_pytorch.t2t import T2TViT
from vit_pytorch.efficient import ViT as EfficientViT
from einops import rearrange, repeat
@@ -60,6 +61,24 @@ class DistillableViT(DistillMixin, ViT):
x = self.transformer(x, mask)
return x
class DistillableT2TViT(DistillMixin, T2TViT):
def __init__(self, *args, **kwargs):
super(DistillableT2TViT, self).__init__(*args, **kwargs)
self.args = args
self.kwargs = kwargs
self.dim = kwargs['dim']
self.num_classes = kwargs['num_classes']
def to_vit(self):
v = T2TViT(*self.args, **self.kwargs)
v.load_state_dict(self.state_dict())
return v
def _attend(self, x, mask):
x = self.dropout(x)
x = self.transformer(x)
return x
class DistillableEfficientViT(DistillMixin, EfficientViT):
def __init__(self, *args, **kwargs):
super(DistillableEfficientViT, self).__init__(*args, **kwargs)
@@ -88,7 +107,7 @@ class DistillWrapper(nn.Module):
alpha = 0.5
):
super().__init__()
assert (isinstance(student, (DistillableViT, DistillableEfficientViT))) , 'student must be a vision transformer'
assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
self.teacher = teacher
self.student = student

166
vit_pytorch/mpp.py Normal file
View File

@@ -0,0 +1,166 @@
import math
from functools import reduce
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat
# helpers
def prob_mask_like(t, prob):
batch, seq_length, _ = t.shape
return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob
def get_mask_subset_with_prob(patched_input, prob):
batch, seq_len, _, device = *patched_input.shape, patched_input.device
max_masked = math.ceil(prob * seq_len)
rand = torch.rand((batch, seq_len), device=device)
_, sampled_indices = rand.topk(max_masked, dim=-1)
new_mask = torch.zeros((batch, seq_len), device=device)
new_mask.scatter_(1, sampled_indices, 1)
return new_mask.bool()
# mpp loss
class MPPLoss(nn.Module):
def __init__(self, patch_size, channels, output_channel_bits,
max_pixel_val):
super(MPPLoss, self).__init__()
self.patch_size = patch_size
self.channels = channels
self.output_channel_bits = output_channel_bits
self.max_pixel_val = max_pixel_val
def forward(self, predicted_patches, target, mask):
# reshape target to patches
p = self.patch_size
target = rearrange(target,
"b c (h p1) (w p2) -> b (h w) c (p1 p2) ",
p1=p,
p2=p)
avg_target = target.mean(dim=3)
bin_size = self.max_pixel_val / self.output_channel_bits
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size)
discretized_target = torch.bucketize(avg_target, channel_bins)
discretized_target = F.one_hot(discretized_target,
self.output_channel_bits)
c, bi = self.channels, self.output_channel_bits
discretized_target = rearrange(discretized_target,
"b n c bi -> b n (c bi)",
c=c,
bi=bi)
bin_mask = 2**torch.arange(c * bi - 1, -1,
-1).to(discretized_target.device,
discretized_target.dtype)
target_label = torch.sum(bin_mask * discretized_target, -1)
predicted_patches = predicted_patches[mask]
target_label = target_label[mask]
loss = F.cross_entropy(predicted_patches, target_label)
return loss
# main class
class MPP(nn.Module):
def __init__(self,
transformer,
patch_size,
dim,
output_channel_bits=3,
channels=3,
max_pixel_val=1.0,
mask_prob=0.15,
replace_prob=0.5,
random_patch_prob=0.5):
super().__init__()
self.transformer = transformer
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
max_pixel_val)
# output transformation
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
# vit related dimensions
self.patch_size = patch_size
# mpp related probabilities
self.mask_prob = mask_prob
self.replace_prob = replace_prob
self.random_patch_prob = random_patch_prob
# token ids
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
def forward(self, input, **kwargs):
transformer = self.transformer
# clone original image for loss
img = input.clone().detach()
# reshape raw image to patches
p = self.patch_size
input = rearrange(input,
'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
p1=p,
p2=p)
mask = get_mask_subset_with_prob(input, self.mask_prob)
# mask input with mask patches with probability of `replace_prob` (keep patches the same with probability 1 - replace_prob)
masked_input = input.clone().detach()
# if random token probability > 0 for mpp
if self.random_patch_prob > 0:
random_patch_sampling_prob = self.random_patch_prob / (
1 - self.replace_prob)
random_patch_prob = prob_mask_like(input,
random_patch_sampling_prob)
bool_random_patch_prob = mask * random_patch_prob == True
random_patches = torch.randint(0,
input.shape[1],
(input.shape[0], input.shape[1]),
device=input.device)
randomized_input = masked_input[
torch.arange(masked_input.shape[0]).unsqueeze(-1),
random_patches]
masked_input[bool_random_patch_prob] = randomized_input[
bool_random_patch_prob]
# [mask] input
replace_prob = prob_mask_like(input, self.replace_prob)
bool_mask_replace = (mask * replace_prob) == True
masked_input[bool_mask_replace] = self.mask_token
# linear embedding of patches
masked_input = transformer.to_patch_embedding[-1](masked_input)
# add cls token to input sequence
b, n, _ = masked_input.shape
cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b)
masked_input = torch.cat((cls_tokens, masked_input), dim=1)
# add positional embeddings to input
masked_input += transformer.pos_embedding[:, :(n + 1)]
masked_input = transformer.dropout(masked_input)
# get generator output and get mpp loss
masked_input = transformer.transformer(masked_input, **kwargs)
cls_logits = self.to_bits(masked_input)
logits = cls_logits[:, 1:, :]
mpp_loss = self.loss(logits, img, mask)
return mpp_loss

View File

@@ -7,6 +7,14 @@ from vit_pytorch.vit_pytorch import Transformer
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def conv_output_size(image_size, kernel_size, stride, padding):
return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
# classes
class RearrangeImage(nn.Module):
@@ -16,20 +24,18 @@ class RearrangeImage(nn.Module):
# main class
class T2TViT(nn.Module):
def __init__(
self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
def __init__(self, *, image_size, num_classes, dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
layers = []
layer_dim = channels
output_image_size = image_size
for i, (kernel_size, stride) in enumerate(t2t_layers):
layer_dim *= kernel_size ** 2
is_first = i == 0
output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)
layers.extend([
RearrangeImage() if not is_first else nn.Identity(),
@@ -41,11 +47,15 @@ class T2TViT(nn.Module):
layers.append(nn.Linear(layer_dim, dim))
self.to_patch_embedding = nn.Sequential(*layers)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.pos_embedding = nn.Parameter(torch.randn(1, output_image_size ** 2 + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
if not exists(transformer):
assert all([exists(depth), exists(heads), exists(mlp_dim)]), 'depth, heads, and mlp_dim must be supplied'
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
else:
self.transformer = transformer
self.pool = pool
self.to_latent = nn.Identity()
@@ -61,7 +71,7 @@ class T2TViT(nn.Module):
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x += self.pos_embedding
x = self.dropout(x)
x = self.transformer(x)