Compare commits

..

11 Commits

Author SHA1 Message Date
Phil Wang
26df10c0b7 fix max pool in nest 2021-05-28 11:06:02 -07:00
Phil Wang
17cb8976df make nest resilient to dimension that are not divisible by number of heads 2021-05-27 22:41:07 -07:00
Phil Wang
daf3abbeb5 add NesT 2021-05-27 22:02:17 -07:00
Phil Wang
b483b16833 0.18.4 2021-05-18 14:40:33 -07:00
Phil Wang
c457573808 Merge pull request #118 from loctruong96/main
update  mpp.py to work on GPU
2021-05-18 14:40:17 -07:00
Loc Truong
e75b6d0251 Update mpp.py
fix issue with GPU device mismatch
2021-05-16 20:07:49 -07:00
Phil Wang
679e5be3e7 apply scale to 2d rel pos bias in levit 2021-05-10 11:37:23 -07:00
Phil Wang
7333979e6b add link to official repo for levit 2021-05-06 13:12:30 -07:00
Phil Wang
74b402377b add image 2021-05-02 15:40:53 -07:00
Phil Wang
41d2d460d0 link to yannic 2021-05-02 14:51:55 -07:00
Phil Wang
04f86dee3c implement SOTA new self-supervised learning technique from facebook for vision transformers, Dino 2021-05-02 14:00:36 -07:00
8 changed files with 230 additions and 9 deletions

View File

@@ -271,6 +271,8 @@ preds = v(img) # (1, 1000)
<a href="https://arxiv.org/abs/2104.01136">This paper</a> proposes a number of changes, including (1) convolutional embedding instead of patch-wise projection (2) downsampling in stages (3) extra non-linearity in attention (4) 2d relative positional biases instead of initial absolute positional bias (5) batchnorm in place of layernorm.
<a href="https://github.com/facebookresearch/LeViT">Official repository</a>
```python
import torch
from vit_pytorch.levit import LeViT
@@ -376,6 +378,32 @@ img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```
## NesT
<img src="./images/nest.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in heirarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the heirarchy. The aggregation is done in the image plane, and contains a convolution to allow it to pass information across the boundary.
You can use it with the following code (ex. NesT-T)
```python
import torch
from vit_pytorch.nest import NesT
nest = NesT(
image_size = 224,
patch_size = 4,
dim = 96,
heads = 3,
num_heirarchies = 3, # number of heirarchies
block_repeats = (8, 4, 1), # the number of transformer blocks at each heirarchy, starting from the bottom
num_classes = 1000
)
img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```
## Masked Patch Prediction
Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
@@ -424,8 +452,12 @@ torch.save(model.state_dict(), './pretrained-net.pt')
## Dino
<img src="./images/dino.png" width="350px"></img>
You can train `ViT` with the recent SOTA self-supervised learning technique, <a href="https://arxiv.org/abs/2104.14294">Dino</a>, with the following code.
<a href="https://www.youtube.com/watch?v=h3ij3F3cPIk">Yannic Kilcher</a> video
```python
import torch
from vit_pytorch import ViT, Dino
@@ -781,6 +813,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{zhang2021aggregating,
title = {Aggregating Nested Transformers},
author = {Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
year = {2021},
eprint = {2105.12723},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{caron2021emerging,
title = {Emerging Properties in Self-Supervised Vision Transformers},

BIN
images/dino.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

BIN
images/nest.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.18.0',
version = '0.19.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -278,8 +278,8 @@ class Dino(nn.Module):
image_one, image_two = self.augment1(x), self.augment2(x)
local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_one)
global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_one)
local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
student_proj_one, _ = self.student_encoder(local_image_one)
student_proj_two, _ = self.student_encoder(local_image_two)

View File

@@ -84,7 +84,7 @@ class Attention(nn.Module):
def apply_pos_bias(self, fmap):
bias = self.pos_bias(self.pos_indices)
bias = rearrange(bias, 'i j h -> () h i j')
return fmap + bias
return fmap + (bias / self.scale)
def forward(self, x):
b, n, *_, h = *x.shape, self.heads

View File

@@ -50,7 +50,7 @@ class MPPLoss(nn.Module):
avg_target = target.mean(dim=3)
bin_size = self.max_pixel_val / self.output_channel_bits
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size)
channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size).to(avg_target.device)
discretized_target = torch.bucketize(avg_target, channel_bins)
discretized_target = F.one_hot(discretized_target,
self.output_channel_bits)
@@ -86,7 +86,6 @@ class MPP(nn.Module):
replace_prob=0.5,
random_patch_prob=0.5):
super().__init__()
self.transformer = transformer
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
max_pixel_val)
@@ -127,8 +126,9 @@ class MPP(nn.Module):
random_patch_sampling_prob = self.random_patch_prob / (
1 - self.replace_prob)
random_patch_prob = prob_mask_like(input,
random_patch_sampling_prob)
bool_random_patch_prob = mask * random_patch_prob == True
random_patch_sampling_prob).to(mask.device)
bool_random_patch_prob = mask * (random_patch_prob == True)
random_patches = torch.randint(0,
input.shape[1],
(input.shape[0], input.shape[1]),
@@ -140,7 +140,7 @@ class MPP(nn.Module):
bool_random_patch_prob]
# [mask] input
replace_prob = prob_mask_like(input, self.replace_prob)
replace_prob = prob_mask_like(input, self.replace_prob).to(mask.device)
bool_mask_replace = (mask * replace_prob) == True
masked_input[bool_mask_replace] = self.mask_token

178
vit_pytorch/nest.py Normal file
View File

@@ -0,0 +1,178 @@
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
# helpers
def cast_tuple(val, depth):
return val if isinstance(val, tuple) else ((val,) * depth)
# classes
class ChanNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mlp_mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mlp_mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(dim * mlp_mult, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dropout = 0.):
super().__init__()
dim_head = dim // heads
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
b, c, h, w, heads = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
return self.to_out(out)
def Aggregate(dim, dim_out):
return nn.Sequential(
nn.Conv2d(dim, dim_out, 3, padding = 1),
ChanNorm(dim_out),
nn.MaxPool2d(3, stride = 2, padding = 1)
)
class Transformer(nn.Module):
def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
self.pos_emb = nn.Parameter(torch.randn(seq_len))
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
]))
def forward(self, x):
*_, h, w = x.shape
pos_emb = self.pos_emb[:(h * w)]
pos_emb = rearrange(pos_emb, '(h w) -> () () h w', h = h, w = w)
x = x + pos_emb
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class NesT(nn.Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
heads,
num_heirarchies,
block_repeats,
mlp_mult = 4,
channels = 3,
dim_head = 64,
dropout = 0.
):
super().__init__()
assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
fmap_size = image_size // patch_size
blocks = 2 ** (num_heirarchies - 1)
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy
heirarchies = list(reversed(range(num_heirarchies)))
mults = [2 ** i for i in heirarchies]
layer_heads = list(map(lambda t: t * heads, mults))
layer_dims = list(map(lambda t: t * dim, mults))
layer_dims = [*layer_dims, layer_dims[-1]]
dim_pairs = zip(layer_dims[:-1], layer_dims[1:])
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
nn.Conv2d(patch_dim, layer_dims[0], 1),
)
block_repeats = cast_tuple(block_repeats, num_heirarchies)
self.layers = nn.ModuleList([])
for level, heads, (dim_in, dim_out), block_repeat in zip(heirarchies, layer_heads, dim_pairs, block_repeats):
is_last = level == 0
depth = block_repeat
self.layers.append(nn.ModuleList([
Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout),
Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
]))
self.mlp_head = nn.Sequential(
ChanNorm(dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, c, h, w = x.shape
num_heirarchies = len(self.layers)
for level, (transformer, aggregate) in zip(reversed(range(num_heirarchies)), self.layers):
block_size = 2 ** level
x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
x = transformer(x)
x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
x = aggregate(x)
return self.mlp_head(x)