mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 16:12:29 +00:00
Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
171fd97a45 | ||
|
|
b900850144 | ||
|
|
78489045cd | ||
|
|
173e07e02e | ||
|
|
0e63766e54 | ||
|
|
a6cbda37b9 | ||
|
|
73de1e8a73 | ||
|
|
1698b7bef8 | ||
|
|
6760d554aa | ||
|
|
a82894846d | ||
|
|
3744ac691a | ||
|
|
6af7bbcd11 | ||
|
|
05edfff33c | ||
|
|
e3205c0a4f | ||
|
|
4fc7365356 | ||
|
|
3f2cbc6e23 | ||
|
|
fc14561de7 | ||
|
|
be5d560821 | ||
|
|
77703ae1fc | ||
|
|
a0a4fa5e7d | ||
|
|
174e71cf53 | ||
|
|
e14bd14a8f | ||
|
|
85314cf0b6 | ||
|
|
5db8d9deed |
132
README.md
132
README.md
@@ -40,7 +40,7 @@ preds = v(img, mask = mask) # (1, 1000)
|
||||
|
||||
## Parameters
|
||||
- `image_size`: int.
|
||||
Image size.
|
||||
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
|
||||
- `patch_size`: int.
|
||||
Number of patches. `image_size` must be divisible by `patch_size`.
|
||||
The number of patches is: ` n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
|
||||
@@ -117,6 +117,103 @@ v = v.to_vit()
|
||||
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
|
||||
```
|
||||
|
||||
## Deep ViT
|
||||
|
||||
This <a href="https://arxiv.org/abs/2103.11886">paper</a> notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a> paper from NLP.
|
||||
|
||||
You can use it as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.deepvit import DeepViT
|
||||
|
||||
v = DeepViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Token-to-Token ViT
|
||||
|
||||
<img src="./t2t.png" width="400px"></img>
|
||||
|
||||
<a href="https://arxiv.org/abs/2101.11986">This paper</a> proposes that the first couple layers should downsample the image sequence by unfolding, leading to overlapping image data in each token as shown in the figure above. You can use this variant of the `ViT` as follows.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.t2t import T2TViT
|
||||
|
||||
v = T2TViT(
|
||||
dim = 512,
|
||||
image_size = 224,
|
||||
depth = 5,
|
||||
heads = 8,
|
||||
mlp_dim = 512,
|
||||
num_classes = 1000,
|
||||
t2t_layers = ((7, 4), (3, 2), (3, 2)) # tuples of the kernel size and stride of each consecutive layers of the initial token to token module
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 224, 224)
|
||||
v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Masked Patch Prediction
|
||||
|
||||
Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import ViT
|
||||
from vit_pytorch.mpp import MPP
|
||||
|
||||
model = ViT(
|
||||
image_size=256,
|
||||
patch_size=32,
|
||||
num_classes=1000,
|
||||
dim=1024,
|
||||
depth=6,
|
||||
heads=8,
|
||||
mlp_dim=2048,
|
||||
dropout=0.1,
|
||||
emb_dropout=0.1
|
||||
)
|
||||
|
||||
mpp_trainer = MPP(
|
||||
transformer=model,
|
||||
patch_size=32,
|
||||
dim=1024,
|
||||
mask_prob=0.15, # probability of using token in masked prediction task
|
||||
random_patch_prob=0.30, # probability of randomly replacing a token being used for mpp
|
||||
replace_prob=0.50, # probability of replacing a token being used for mpp with the mask token
|
||||
)
|
||||
|
||||
opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)
|
||||
|
||||
def sample_unlabelled_images():
|
||||
return torch.randn(20, 3, 256, 256)
|
||||
|
||||
for _ in range(100):
|
||||
images = sample_unlabelled_images()
|
||||
loss = mpp_trainer(images)
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
opt.step()
|
||||
|
||||
# save your improved network
|
||||
torch.save(model.state_dict(), './pretrained-net.pt')
|
||||
```
|
||||
|
||||
## Research Ideas
|
||||
|
||||
### Self Supervised Training
|
||||
@@ -173,23 +270,22 @@ A pytorch-lightning script is ready for you to use at the repository link above.
|
||||
|
||||
There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plugin your own sparse attention transformer.
|
||||
|
||||
An example with <a href="https://arxiv.org/abs/2006.04768">Linformer</a>
|
||||
An example with <a href="https://arxiv.org/abs/2102.03902">Nystromformer</a>
|
||||
|
||||
```bash
|
||||
$ pip install linformer
|
||||
$ pip install nystrom-attention
|
||||
```
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.efficient import ViT
|
||||
from linformer import Linformer
|
||||
from nystrom_attention import Nystromformer
|
||||
|
||||
efficient_transformer = Linformer(
|
||||
efficient_transformer = Nystromformer(
|
||||
dim = 512,
|
||||
seq_len = 4096 + 1, # 64 x 64 patches + 1 cls token
|
||||
depth = 12,
|
||||
heads = 8,
|
||||
k = 256
|
||||
num_landmarks = 256
|
||||
)
|
||||
|
||||
v = ViT(
|
||||
@@ -274,6 +370,28 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{yuan2021tokenstotoken,
|
||||
title = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
|
||||
author = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
|
||||
year = {2021},
|
||||
eprint = {2101.11986},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{zhou2021deepvit,
|
||||
title = {DeepViT: Towards Deeper Vision Transformer},
|
||||
author = {Daquan Zhou and Bingyi Kang and Xiaojie Jin and Linjie Yang and Xiaochen Lian and Qibin Hou and Jiashi Feng},
|
||||
year = {2021},
|
||||
eprint = {2103.11886},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
|
||||
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.6.6',
|
||||
version = '0.9.3',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -1 +1 @@
|
||||
from vit_pytorch.vit_pytorch import ViT
|
||||
from vit_pytorch.vit import ViT
|
||||
|
||||
136
vit_pytorch/deepvit.py
Normal file
136
vit_pytorch/deepvit.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(x, **kwargs) + x
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
|
||||
|
||||
self.reattn_norm = nn.Sequential(
|
||||
Rearrange('b h i j -> b i j h'),
|
||||
nn.LayerNorm(heads),
|
||||
Rearrange('b i j h -> b h i j')
|
||||
)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
# attention
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
attn = dots.softmax(dim=-1)
|
||||
|
||||
# re-attention
|
||||
|
||||
attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
|
||||
attn = self.reattn_norm(attn)
|
||||
|
||||
# aggregate and out
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
|
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x)
|
||||
x = ff(x)
|
||||
return x
|
||||
|
||||
class DeepViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
@@ -1,7 +1,8 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from vit_pytorch.vit_pytorch import ViT
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.t2t import T2TViT
|
||||
from vit_pytorch.efficient import ViT as EfficientViT
|
||||
|
||||
from einops import rearrange, repeat
|
||||
@@ -14,11 +15,9 @@ def exists(val):
|
||||
# classes
|
||||
|
||||
class DistillMixin:
|
||||
def forward(self, img, distill_token = None, mask = None):
|
||||
p, distilling = self.patch_size, exists(distill_token)
|
||||
|
||||
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
|
||||
x = self.patch_to_embedding(x)
|
||||
def forward(self, img, distill_token = None):
|
||||
distilling = exists(distill_token)
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
@@ -29,7 +28,7 @@ class DistillMixin:
|
||||
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((x, distill_tokens), dim = 1)
|
||||
|
||||
x = self._attend(x, mask)
|
||||
x = self._attend(x)
|
||||
|
||||
if distilling:
|
||||
x, distill_tokens = x[:, :-1], x[:, -1]
|
||||
@@ -57,9 +56,27 @@ class DistillableViT(DistillMixin, ViT):
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x, mask):
|
||||
def _attend(self, x):
|
||||
x = self.dropout(x)
|
||||
x = self.transformer(x, mask)
|
||||
x = self.transformer(x)
|
||||
return x
|
||||
|
||||
class DistillableT2TViT(DistillMixin, T2TViT):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DistillableT2TViT, self).__init__(*args, **kwargs)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.dim = kwargs['dim']
|
||||
self.num_classes = kwargs['num_classes']
|
||||
|
||||
def to_vit(self):
|
||||
v = T2TViT(*self.args, **self.kwargs)
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x):
|
||||
x = self.dropout(x)
|
||||
x = self.transformer(x)
|
||||
return x
|
||||
|
||||
class DistillableEfficientViT(DistillMixin, EfficientViT):
|
||||
@@ -75,7 +92,7 @@ class DistillableEfficientViT(DistillMixin, EfficientViT):
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x, mask):
|
||||
def _attend(self, x):
|
||||
return self.transformer(x)
|
||||
|
||||
# knowledge distillation wrapper
|
||||
@@ -90,7 +107,7 @@ class DistillWrapper(nn.Module):
|
||||
alpha = 0.5
|
||||
):
|
||||
super().__init__()
|
||||
assert (isinstance(student, (DistillableViT, DistillableEfficientViT))) , 'student must be a vision transformer'
|
||||
assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
|
||||
|
||||
self.teacher = teacher
|
||||
self.student = student
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
|
||||
@@ -10,10 +11,12 @@ class ViT(nn.Module):
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.patch_to_embedding = nn.Linear(patch_dim, dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.transformer = transformer
|
||||
|
||||
@@ -26,10 +29,7 @@ class ViT(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
p = self.patch_size
|
||||
|
||||
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
|
||||
x = self.patch_to_embedding(x)
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
|
||||
166
vit_pytorch/mpp.py
Normal file
166
vit_pytorch/mpp.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import math
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def prob_mask_like(t, prob):
|
||||
batch, seq_length, _ = t.shape
|
||||
return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob
|
||||
|
||||
|
||||
def get_mask_subset_with_prob(patched_input, prob):
|
||||
batch, seq_len, _, device = *patched_input.shape, patched_input.device
|
||||
max_masked = math.ceil(prob * seq_len)
|
||||
|
||||
rand = torch.rand((batch, seq_len), device=device)
|
||||
_, sampled_indices = rand.topk(max_masked, dim=-1)
|
||||
|
||||
new_mask = torch.zeros((batch, seq_len), device=device)
|
||||
new_mask.scatter_(1, sampled_indices, 1)
|
||||
return new_mask.bool()
|
||||
|
||||
|
||||
# mpp loss
|
||||
|
||||
|
||||
class MPPLoss(nn.Module):
|
||||
def __init__(self, patch_size, channels, output_channel_bits,
|
||||
max_pixel_val):
|
||||
super(MPPLoss, self).__init__()
|
||||
self.patch_size = patch_size
|
||||
self.channels = channels
|
||||
self.output_channel_bits = output_channel_bits
|
||||
self.max_pixel_val = max_pixel_val
|
||||
|
||||
def forward(self, predicted_patches, target, mask):
|
||||
# reshape target to patches
|
||||
p = self.patch_size
|
||||
target = rearrange(target,
|
||||
"b c (h p1) (w p2) -> b (h w) c (p1 p2) ",
|
||||
p1=p,
|
||||
p2=p)
|
||||
|
||||
avg_target = target.mean(dim=3)
|
||||
|
||||
bin_size = self.max_pixel_val / self.output_channel_bits
|
||||
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size)
|
||||
discretized_target = torch.bucketize(avg_target, channel_bins)
|
||||
discretized_target = F.one_hot(discretized_target,
|
||||
self.output_channel_bits)
|
||||
c, bi = self.channels, self.output_channel_bits
|
||||
discretized_target = rearrange(discretized_target,
|
||||
"b n c bi -> b n (c bi)",
|
||||
c=c,
|
||||
bi=bi)
|
||||
|
||||
bin_mask = 2**torch.arange(c * bi - 1, -1,
|
||||
-1).to(discretized_target.device,
|
||||
discretized_target.dtype)
|
||||
target_label = torch.sum(bin_mask * discretized_target, -1)
|
||||
|
||||
predicted_patches = predicted_patches[mask]
|
||||
target_label = target_label[mask]
|
||||
loss = F.cross_entropy(predicted_patches, target_label)
|
||||
return loss
|
||||
|
||||
|
||||
# main class
|
||||
|
||||
|
||||
class MPP(nn.Module):
|
||||
def __init__(self,
|
||||
transformer,
|
||||
patch_size,
|
||||
dim,
|
||||
output_channel_bits=3,
|
||||
channels=3,
|
||||
max_pixel_val=1.0,
|
||||
mask_prob=0.15,
|
||||
replace_prob=0.5,
|
||||
random_patch_prob=0.5):
|
||||
super().__init__()
|
||||
|
||||
self.transformer = transformer
|
||||
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
|
||||
max_pixel_val)
|
||||
|
||||
# output transformation
|
||||
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
|
||||
|
||||
# vit related dimensions
|
||||
self.patch_size = patch_size
|
||||
|
||||
# mpp related probabilities
|
||||
self.mask_prob = mask_prob
|
||||
self.replace_prob = replace_prob
|
||||
self.random_patch_prob = random_patch_prob
|
||||
|
||||
# token ids
|
||||
self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
|
||||
|
||||
def forward(self, input, **kwargs):
|
||||
transformer = self.transformer
|
||||
# clone original image for loss
|
||||
img = input.clone().detach()
|
||||
|
||||
# reshape raw image to patches
|
||||
p = self.patch_size
|
||||
input = rearrange(input,
|
||||
'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
|
||||
p1=p,
|
||||
p2=p)
|
||||
|
||||
mask = get_mask_subset_with_prob(input, self.mask_prob)
|
||||
|
||||
# mask input with mask patches with probability of `replace_prob` (keep patches the same with probability 1 - replace_prob)
|
||||
masked_input = input.clone().detach()
|
||||
|
||||
# if random token probability > 0 for mpp
|
||||
if self.random_patch_prob > 0:
|
||||
random_patch_sampling_prob = self.random_patch_prob / (
|
||||
1 - self.replace_prob)
|
||||
random_patch_prob = prob_mask_like(input,
|
||||
random_patch_sampling_prob)
|
||||
bool_random_patch_prob = mask * random_patch_prob == True
|
||||
random_patches = torch.randint(0,
|
||||
input.shape[1],
|
||||
(input.shape[0], input.shape[1]),
|
||||
device=input.device)
|
||||
randomized_input = masked_input[
|
||||
torch.arange(masked_input.shape[0]).unsqueeze(-1),
|
||||
random_patches]
|
||||
masked_input[bool_random_patch_prob] = randomized_input[
|
||||
bool_random_patch_prob]
|
||||
|
||||
# [mask] input
|
||||
replace_prob = prob_mask_like(input, self.replace_prob)
|
||||
bool_mask_replace = (mask * replace_prob) == True
|
||||
masked_input[bool_mask_replace] = self.mask_token
|
||||
|
||||
# linear embedding of patches
|
||||
masked_input = transformer.to_patch_embedding[-1](masked_input)
|
||||
|
||||
# add cls token to input sequence
|
||||
b, n, _ = masked_input.shape
|
||||
cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b)
|
||||
masked_input = torch.cat((cls_tokens, masked_input), dim=1)
|
||||
|
||||
# add positional embeddings to input
|
||||
masked_input += transformer.pos_embedding[:, :(n + 1)]
|
||||
masked_input = transformer.dropout(masked_input)
|
||||
|
||||
# get generator output and get mpp loss
|
||||
masked_input = transformer.transformer(masked_input, **kwargs)
|
||||
cls_logits = self.to_bits(masked_input)
|
||||
logits = cls_logits[:, 1:, :]
|
||||
|
||||
mpp_loss = self.loss(logits, img, mask)
|
||||
|
||||
return mpp_loss
|
||||
82
vit_pytorch/t2t.py
Normal file
82
vit_pytorch/t2t.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vit_pytorch.vit import Transformer
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def conv_output_size(image_size, kernel_size, stride, padding):
|
||||
return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
|
||||
|
||||
# classes
|
||||
|
||||
class RearrangeImage(nn.Module):
|
||||
def forward(self, x):
|
||||
return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1])))
|
||||
|
||||
# main class
|
||||
|
||||
class T2TViT(nn.Module):
|
||||
def __init__(self, *, image_size, num_classes, dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))):
|
||||
super().__init__()
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
layers = []
|
||||
layer_dim = channels
|
||||
output_image_size = image_size
|
||||
|
||||
for i, (kernel_size, stride) in enumerate(t2t_layers):
|
||||
layer_dim *= kernel_size ** 2
|
||||
is_first = i == 0
|
||||
output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)
|
||||
|
||||
layers.extend([
|
||||
RearrangeImage() if not is_first else nn.Identity(),
|
||||
nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
|
||||
Rearrange('b c n -> b n c'),
|
||||
Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout),
|
||||
])
|
||||
|
||||
layers.append(nn.Linear(layer_dim, dim))
|
||||
self.to_patch_embedding = nn.Sequential(*layers)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, output_image_size ** 2 + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
if not exists(transformer):
|
||||
assert all([exists(depth), exists(heads), exists(mlp_dim)]), 'depth, heads, and mlp_dim must be supplied'
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
else:
|
||||
self.transformer = transformer
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
@@ -1,9 +1,9 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
@@ -37,39 +37,34 @@ class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim ** -0.5
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x, mask = None):
|
||||
def forward(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
|
||||
mask_value = -torch.finfo(dots.dtype).max
|
||||
|
||||
if mask is not None:
|
||||
mask = F.pad(mask.flatten(1), (1, 0), value = True)
|
||||
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
|
||||
mask = mask[:, None, :] * mask[:, :, None]
|
||||
dots.masked_fill_(~mask, mask_value)
|
||||
del mask
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
|
||||
attn = dots.softmax(dim=-1)
|
||||
|
||||
out = torch.einsum('bhij,bhjd->bhid', attn, v)
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
@@ -77,9 +72,9 @@ class Transformer(nn.Module):
|
||||
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
|
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
|
||||
]))
|
||||
def forward(self, x, mask = None):
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, mask = mask)
|
||||
x = attn(x)
|
||||
x = ff(x)
|
||||
return x
|
||||
|
||||
@@ -89,13 +84,14 @@ class ViT(nn.Module):
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.patch_to_embedding = nn.Linear(patch_dim, dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
@@ -109,11 +105,8 @@ class ViT(nn.Module):
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img, mask = None):
|
||||
p = self.patch_size
|
||||
|
||||
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
|
||||
x = self.patch_to_embedding(x)
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
@@ -121,7 +114,7 @@ class ViT(nn.Module):
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x, mask)
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
Reference in New Issue
Block a user