Mirror of https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 16:12:29 +00:00
Compare commits
16 Commits
| SHA1 |
|---|
| 53884f583f |
| e616b5dcbc |
| 60ad4e266e |
| a254a0258a |
| 26df10c0b7 |
| 17cb8976df |
| daf3abbeb5 |
| b483b16833 |
| c457573808 |
| e75b6d0251 |
| 679e5be3e7 |
| 7333979e6b |
| 74b402377b |
| 41d2d460d0 |
| a2df363224 |
| 710b6d57d3 |
README.md (43)
@@ -271,6 +271,8 @@ preds = v(img) # (1, 1000)
<a href="https://arxiv.org/abs/2104.01136">This paper</a> proposes a number of changes, including (1) convolutional embedding instead of patch-wise projection, (2) downsampling in stages, (3) extra non-linearity in attention, (4) 2d relative positional biases instead of an initial absolute positional bias, and (5) batchnorm in place of layernorm.

<a href="https://github.com/facebookresearch/LeViT">Official repository</a>

```python
import torch
from vit_pytorch.levit import LeViT
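# illustrative instantiation (a hedged sketch; the hyperparameter names and values
# below are assumptions for illustration, not part of this diff)
levit = LeViT(
    image_size = 224,
    num_classes = 1000,
    stages = 3,             # number of stages
    dim = (256, 384, 512),  # dimensions at each stage
    depth = 4,              # transformer depth at each stage
    heads = (4, 6, 8),      # heads at each stage
    mlp_mult = 2,
    dropout = 0.1
)

preds = levit(torch.randn(1, 3, 224, 224)) # (1, 1000)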
@@ -376,6 +378,32 @@ img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```

## NesT

<img src="./images/nest.png" width="400px"></img>

This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within the tokens of local blocks, which are aggregated as one moves up the hierarchy. The aggregation is done in the image plane, and consists of a convolution and a subsequent maxpool, which allows information to pass across block boundaries (a short partition/merge sketch follows the usage example below).

You can use it with the following code (e.g. NesT-T)

```python
import torch
from vit_pytorch.nest import NesT

nest = NesT(
    image_size = 224,
    patch_size = 4,
    dim = 96,
    heads = 3,
    num_hierarchies = 3,        # number of hierarchies
    block_repeats = (8, 4, 1),  # the number of transformer blocks at each hierarchy, starting from the bottom
    num_classes = 1000
)

img = torch.randn(1, 3, 224, 224)

pred = nest(img) # (1, 1000)
```
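For intuition, here is a minimal sketch of the block partition and merge step that keeps attention local to each block; the tensor sizes and block grid below are illustrative assumptions, not taken from the library.

```python
import torch
from einops import rearrange

x = torch.randn(1, 96, 56, 56)   # feature map after patch embedding (sizes assumed for illustration)
b1 = b2 = 4                      # a 4 x 4 grid of local blocks at the bottom of the hierarchy

# partition into independent blocks, so attention only mixes tokens inside one block
blocks = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = b1, b2 = b2)

# ... block-wise transformer layers would run here ...

# merge the blocks back into the image plane before the conv + maxpool aggregation
merged = rearrange(blocks, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = b1, b2 = b2)
```

This mirrors the two `rearrange` calls in the forward pass of `vit_pytorch/nest.py` below, where the block grid halves at each level while the per-block spatial size stays fixed.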
## Masked Patch Prediction

Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
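For illustration, a minimal pre-training sketch; the `ViT` configuration and the MPP hyperparameters here are assumptions, and `mean`/`std` are the optional un-normalization statistics that `MPPLoss` accepts after this change.

```python
import torch
from vit_pytorch import ViT
from vit_pytorch.mpp import MPP

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mpp_trainer = MPP(
    transformer = model,
    patch_size = 32,
    dim = 1024,
    mask_prob = 0.15,              # probability of using a token for the masked prediction task
    random_patch_prob = 0.30,      # probability of replacing a used token with a random patch
    replace_prob = 0.50,           # probability of replacing a used token with the mask token
    mean = [0.485, 0.456, 0.406],  # optional: statistics used to normalize the inputs (ImageNet values assumed)
    std = [0.229, 0.224, 0.225]
)

opt = torch.optim.AdamW(mpp_trainer.parameters(), lr = 3e-4)

images = torch.randn(8, 3, 256, 256)  # stand-in for a batch of unlabelled images

loss = mpp_trainer(images)
opt.zero_grad()
loss.backward()
opt.step()
```

When `mean` and `std` are provided, `MPPLoss` un-normalizes the targets before discretizing them into bins for the cross-entropy objective.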
@@ -424,8 +452,12 @@ torch.save(model.state_dict(), './pretrained-net.pt')
## Dino

<img src="./images/dino.png" width="350px"></img>

You can train `ViT` with the recent SOTA self-supervised learning technique, <a href="https://arxiv.org/abs/2104.14294">Dino</a>, with the following code.

<a href="https://www.youtube.com/watch?v=h3ij3F3cPIk">Yannic Kilcher</a> video

```python
import torch
from vit_pytorch import ViT, Dino
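# a hedged sketch of self-supervised training with the Dino wrapper; the keyword
# arguments and layer name below are assumptions for illustration, not part of this diff
model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

learner = Dino(
    model,
    image_size = 256,
    hidden_layer = 'to_latent'  # which layer's output to use as the embedding (name assumed)
)

opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)

images = torch.randn(8, 3, 256, 256)  # stand-in for a batch of unlabelled images
loss = learner(images)
opt.zero_grad()
loss.backward()
opt.step()
learner.update_moving_average()  # update the teacher's exponential moving average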
@@ -781,6 +813,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@misc{zhang2021aggregating,
    title   = {Aggregating Nested Transformers},
    author  = {Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
    year    = {2021},
    eprint  = {2105.12723},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

```bibtex
@misc{caron2021emerging,
    title   = {Emerging Properties in Self-Supervised Vision Transformers},
images/dino.png (new binary file, 84 KiB)
images/nest.png (new binary file, 75 KiB)
setup.py (2)
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'vit-pytorch',
  packages = find_packages(exclude=['examples']),
-  version = '0.18.1',
+  version = '0.19.5',
  license='MIT',
  description = 'Vision Transformer (ViT) - Pytorch',
  author = 'Phil Wang',
vit_pytorch/levit.py
@@ -84,7 +84,7 @@ class Attention(nn.Module):
    def apply_pos_bias(self, fmap):
        bias = self.pos_bias(self.pos_indices)
        bias = rearrange(bias, 'i j h -> () h i j')
-        return fmap + bias
+        return fmap + (bias / self.scale)

    def forward(self, x):
        b, n, *_, h = *x.shape, self.heads
vit_pytorch/mpp.py
@@ -32,14 +32,27 @@ def get_mask_subset_with_prob(patched_input, prob):
class MPPLoss(nn.Module):
    def __init__(self, patch_size, channels, output_channel_bits,
-                 max_pixel_val):
+                 max_pixel_val, mean, std):
        super(MPPLoss, self).__init__()
        self.patch_size = patch_size
        self.channels = channels
        self.output_channel_bits = output_channel_bits
        self.max_pixel_val = max_pixel_val

+        if mean:
+            self.mean = torch.tensor(mean).view(-1, 1, 1)
+        else:
+            self.mean = None
+        if std:
+            self.std = torch.tensor(std).view(-1, 1, 1)
+        else:
+            self.std = None

    def forward(self, predicted_patches, target, mask):
+        # un-normalize input
+        if self.mean is not None and self.std is not None:
+            target = target * self.std + self.mean

        # reshape target to patches
        p = self.patch_size
        target = rearrange(target,
@@ -50,7 +63,7 @@ class MPPLoss(nn.Module):
        avg_target = target.mean(dim=3)

        bin_size = self.max_pixel_val / self.output_channel_bits
-        channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size)
+        channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size).to(avg_target.device)
        discretized_target = torch.bucketize(avg_target, channel_bins)
        discretized_target = F.one_hot(discretized_target,
                                       self.output_channel_bits)
@@ -64,7 +77,6 @@ class MPPLoss(nn.Module):
                               -1).to(discretized_target.device,
                                      discretized_target.dtype)
        target_label = torch.sum(bin_mask * discretized_target, -1)

        predicted_patches = predicted_patches[mask]
        target_label = target_label[mask]
        loss = F.cross_entropy(predicted_patches, target_label)
@@ -84,12 +96,13 @@ class MPP(nn.Module):
                 max_pixel_val=1.0,
                 mask_prob=0.15,
                 replace_prob=0.5,
-                 random_patch_prob=0.5):
+                 random_patch_prob=0.5,
+                 mean=None,
+                 std=None):
        super().__init__()

        self.transformer = transformer
        self.loss = MPPLoss(patch_size, channels, output_channel_bits,
-                            max_pixel_val)
+                            max_pixel_val, mean, std)

        # output transformation
        self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
@@ -103,7 +116,7 @@ class MPP(nn.Module):
        self.random_patch_prob = random_patch_prob

        # token ids
-        self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels))
+        self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))

    def forward(self, input, **kwargs):
        transformer = self.transformer
@@ -127,8 +140,9 @@ class MPP(nn.Module):
            random_patch_sampling_prob = self.random_patch_prob / (
                1 - self.replace_prob)
            random_patch_prob = prob_mask_like(input,
-                                               random_patch_sampling_prob)
-            bool_random_patch_prob = mask * random_patch_prob == True
+                                               random_patch_sampling_prob).to(mask.device)
+
+            bool_random_patch_prob = mask * (random_patch_prob == True)
            random_patches = torch.randint(0,
                                           input.shape[1],
                                           (input.shape[0], input.shape[1]),
@@ -140,7 +154,7 @@ class MPP(nn.Module):
                                           bool_random_patch_prob]

        # [mask] input
-        replace_prob = prob_mask_like(input, self.replace_prob)
+        replace_prob = prob_mask_like(input, self.replace_prob).to(mask.device)
        bool_mask_replace = (mask * replace_prob) == True
        masked_input[bool_mask_replace] = self.mask_token
vit_pytorch/nest.py (new file, 169 lines)
@@ -0,0 +1,169 @@
from functools import partial
import torch
from torch import nn, einsum

from einops import rearrange
from einops.layers.torch import Rearrange, Reduce

# helpers

def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else ((val,) * depth)

LayerNorm = partial(nn.InstanceNorm2d, affine = True)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, mlp_mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mlp_mult, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mlp_mult, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dropout = 0.):
        super().__init__()
        dim_head = dim // heads
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        b, c, h, w, heads = *x.shape, self.heads

        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)

def Aggregate(dim, dim_out):
    return nn.Sequential(
        nn.Conv2d(dim, dim_out, 3, padding = 1),
        LayerNorm(dim_out),
        nn.MaxPool2d(3, stride = 2, padding = 1)
    )

class Transformer(nn.Module):
    def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.pos_emb = nn.Parameter(torch.randn(seq_len))

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
            ]))
    def forward(self, x):
        *_, h, w = x.shape

        pos_emb = self.pos_emb[:(h * w)]
        pos_emb = rearrange(pos_emb, '(h w) -> () () h w', h = h, w = w)
        x = x + pos_emb

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class NesT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        heads,
        num_hierarchies,
        block_repeats,
        mlp_mult = 4,
        channels = 3,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        fmap_size = image_size // patch_size
        blocks = 2 ** (num_hierarchies - 1)

        seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across the hierarchy
        hierarchies = list(reversed(range(num_hierarchies)))
        mults = [2 ** i for i in hierarchies]

        layer_heads = list(map(lambda t: t * heads, mults))
        layer_dims = list(map(lambda t: t * dim, mults))

        layer_dims = [*layer_dims, layer_dims[-1]]
        dim_pairs = zip(layer_dims[:-1], layer_dims[1:])

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
            nn.Conv2d(patch_dim, layer_dims[0], 1),
        )

        block_repeats = cast_tuple(block_repeats, num_hierarchies)

        self.layers = nn.ModuleList([])

        for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats):
            is_last = level == 0
            depth = block_repeat

            self.layers.append(nn.ModuleList([
                Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout),
                Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
            ]))

        self.mlp_head = nn.Sequential(
            LayerNorm(dim),
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, c, h, w = x.shape

        num_hierarchies = len(self.layers)

        for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
            block_size = 2 ** level
            x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
            x = transformer(x)
            x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
            x = aggregate(x)

        return self.mlp_head(x)