Compare commits

...

6 Commits

Author SHA1 Message Date
Phil Wang
72c34f6bae add CrossViT 2021-03-30 00:49:33 -07:00
Phil Wang
d04ce06a30 make recorder work for t2t and deepvit 2021-03-29 18:16:34 -07:00
Phil Wang
8135d70e4e use hooks to retrieve attention maps for user without modifying ViT 2021-03-29 15:10:12 -07:00
Phil Wang
3067155cea add recorder class, for recording attention across layers, for researchers 2021-03-29 11:08:19 -07:00
Phil Wang
ab7315cca1 cleanup 2021-03-27 22:14:16 -07:00
Phil Wang
15294c304e remove masking, as it complicates with little benefit 2021-03-23 12:18:47 -07:00
10 changed files with 502 additions and 14 deletions

View File

@@ -1,4 +1,4 @@
<img src="./vit.gif" width="500px"></img>
<img src="./images/vit.gif" width="500px"></img>
## Vision Transformer - Pytorch
@@ -33,9 +33,8 @@ v = ViT(
)
img = torch.randn(1, 3, 256, 256)
mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to attend to
preds = v(img, mask = mask) # (1, 1000)
preds = v(img) # (1, 1000)
```
## Parameters
@@ -64,7 +63,7 @@ Embedding dropout rate.
## Distillation
<img src="./distill.png" width="300px"></img>
<img src="./images/distill.png" width="300px"></img>
A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that use of a distillation token for distilling knowledge from convolutional nets to vision transformer can yield small and efficient vision transformers. This repository offers the means to do distillation easily.
@@ -146,7 +145,7 @@ preds = v(img) # (1, 1000)
## Token-to-Token ViT
<img src="./t2t.png" width="400px"></img>
<img src="./images/t2t.png" width="400px"></img>
<a href="https://arxiv.org/abs/2101.11986">This paper</a> proposes that the first couple layers should downsample the image sequence by unfolding, leading to overlapping image data in each token as shown in the figure above. You can use this variant of the `ViT` as follows.
@@ -214,6 +213,47 @@ for _ in range(100):
torch.save(model.state_dict(), './pretrained-net.pt')
```
## Accessing Attention
If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below
```python
import torch
from vit_pytorch.vit import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
# import Recorder and wrap the ViT
from vit_pytorch.recorder import Recorder
v = Recorder(v)
# forward pass now returns predictions and the attention maps
img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)
# there is one extra patch due to the CLS token
attns # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)
```
to cleanup the class and the hooks once you have collected enough data
```python
v = v.eject() # wrapper is discarded and original ViT instance is returned
```
## Research Ideas
### Self Supervised Training
@@ -392,6 +432,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@misc{chen2021crossvit,
title = {CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification},
author = {Chun-Fu Chen and Quanfu Fan and Rameswar Panda},
year = {2021},
eprint = {2103.14899},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},

View File

Before

Width:  |  Height:  |  Size: 49 KiB

After

Width:  |  Height:  |  Size: 49 KiB

View File

Before

Width:  |  Height:  |  Size: 109 KiB

After

Width:  |  Height:  |  Size: 109 KiB

View File

Before

Width:  |  Height:  |  Size: 5.8 MiB

After

Width:  |  Height:  |  Size: 5.8 MiB

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.9.1',
version = '0.10.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

268
vit_pytorch/cross_vit.py Normal file
View File

@@ -0,0 +1,268 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# pre-layernorm
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
# feedforward
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# attention
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, context = None, kv_include_self = False):
b, n, _, h = *x.shape, self.heads
context = default(context, x)
if kv_include_self:
context = torch.cat((x, context), dim = 1) # cross attention requires CLS token includes itself as key / value
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
# transformer encoder, for small and large patches
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
self.norm = nn.LayerNorm(dim)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
# projecting CLS tokens, in the case that small and large patch tokens have different dimensions
class ProjectInOut(nn.Module):
def __init__(self, dim_in, dim_out, fn):
super().__init__()
self.fn = fn
need_projection = dim_in != dim_out
self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity()
self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity()
def forward(self, x, *args, **kwargs):
x = self.project_in(x)
x = self.fn(x, *args, **kwargs)
x = self.project_out(x)
return x
# cross attention transformer
class CrossTransformer(nn.Module):
def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
ProjectInOut(sm_dim, lg_dim, PreNorm(lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout))),
ProjectInOut(lg_dim, sm_dim, PreNorm(sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout)))
]))
def forward(self, sm_tokens, lg_tokens):
for sm_attend_lg, lg_attend_sm in self.layers:
(sm_cls, sm_patch_tokens), (lg_cls, lg_patch_tokens) = map(lambda t: (t[:, :1], t[:, 1:]), (sm_tokens, lg_tokens))
sm_cls = sm_attend_lg(sm_cls, context = lg_patch_tokens, kv_include_self = True) + sm_cls
lg_cls = lg_attend_sm(lg_cls, context = sm_patch_tokens, kv_include_self = True) + lg_cls
return sm_tokens, lg_tokens
# multi-scale encoder
class MultiScaleEncoder(nn.Module):
def __init__(
self,
*,
depth,
sm_dim,
lg_dim,
sm_enc_params,
lg_enc_params,
cross_attn_heads,
cross_attn_depth,
cross_attn_dim_head = 64,
dropout = 0.
):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Transformer(dim = sm_dim, dropout = dropout, **sm_enc_params),
Transformer(dim = lg_dim, dropout = dropout, **lg_enc_params),
CrossTransformer(sm_dim = sm_dim, lg_dim = lg_dim, depth = cross_attn_depth, heads = cross_attn_heads, dim_head = cross_attn_dim_head, dropout = dropout)
]))
def forward(self, sm_tokens, lg_tokens):
for sm_enc, lg_enc, cross_attend in self.layers:
sm_tokens, lg_tokens = sm_enc(sm_tokens), lg_enc(lg_tokens)
sm_tokens, lg_tokens = cross_attend(sm_tokens, lg_tokens)
return sm_tokens, lg_tokens
# patch-based image to token embedder
class ImageEmbedder(nn.Module):
def __init__(
self,
*,
dim,
image_size,
patch_size,
dropout = 0.
):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = 3 * patch_size ** 2
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(dropout)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
return self.dropout(x)
# cross ViT class
class CrossViT(nn.Module):
def __init__(
self,
*,
image_size,
num_classes,
sm_dim,
lg_dim,
sm_patch_size = 12,
sm_enc_depth = 1,
sm_enc_heads = 8,
sm_enc_mlp_dim = 2048,
sm_enc_dim_head = 64,
lg_patch_size = 16,
lg_enc_depth = 4,
lg_enc_heads = 8,
lg_enc_mlp_dim = 2048,
lg_enc_dim_head = 64,
cross_attn_depth = 2,
cross_attn_heads = 8,
cross_attn_dim_head = 64,
depth = 3,
dropout = 0.1,
emb_dropout = 0.1
):
super().__init__()
self.sm_image_embedder = ImageEmbedder(dim = sm_dim, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
self.lg_image_embedder = ImageEmbedder(dim = lg_dim, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)
self.multi_scale_encoder = MultiScaleEncoder(
depth = depth,
sm_dim = sm_dim,
lg_dim = lg_dim,
cross_attn_heads = cross_attn_heads,
cross_attn_dim_head = cross_attn_dim_head,
cross_attn_depth = cross_attn_depth,
sm_enc_params = dict(
depth = sm_enc_depth,
heads = sm_enc_heads,
mlp_dim = sm_enc_mlp_dim,
dim_head = sm_enc_dim_head
),
lg_enc_params = dict(
depth = lg_enc_depth,
heads = lg_enc_heads,
mlp_dim = lg_enc_mlp_dim,
dim_head = lg_enc_dim_head
),
dropout = dropout
)
self.sm_mlp_head = nn.Sequential(nn.LayerNorm(sm_dim), nn.Linear(sm_dim, num_classes))
self.lg_mlp_head = nn.Sequential(nn.LayerNorm(lg_dim), nn.Linear(lg_dim, num_classes))
def forward(self, img):
sm_tokens = self.sm_image_embedder(img)
lg_tokens = self.lg_image_embedder(img)
sm_tokens, lg_tokens = self.multi_scale_encoder(sm_tokens, lg_tokens)
sm_cls, lg_cls = map(lambda t: t[:, 0], (sm_tokens, lg_tokens))
sm_logits = self.sm_mlp_head(sm_cls)
lg_logits = self.lg_mlp_head(lg_cls)
return sm_logits + lg_logits

View File

@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from torch import nn
from vit_pytorch.vit_pytorch import ViT
from vit_pytorch.vit import ViT
from vit_pytorch.t2t import T2TViT
from vit_pytorch.efficient import ViT as EfficientViT
@@ -15,7 +15,7 @@ def exists(val):
# classes
class DistillMixin:
def forward(self, img, distill_token = None, mask = None):
def forward(self, img, distill_token = None):
distilling = exists(distill_token)
x = self.to_patch_embedding(img)
b, n, _ = x.shape
@@ -28,7 +28,7 @@ class DistillMixin:
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
x = self._attend(x, mask)
x = self._attend(x)
if distilling:
x, distill_tokens = x[:, :-1], x[:, -1]
@@ -56,9 +56,9 @@ class DistillableViT(DistillMixin, ViT):
v.load_state_dict(self.state_dict())
return v
def _attend(self, x, mask):
def _attend(self, x):
x = self.dropout(x)
x = self.transformer(x, mask)
x = self.transformer(x)
return x
class DistillableT2TViT(DistillMixin, T2TViT):
@@ -74,7 +74,7 @@ class DistillableT2TViT(DistillMixin, T2TViT):
v.load_state_dict(self.state_dict())
return v
def _attend(self, x, mask):
def _attend(self, x):
x = self.dropout(x)
x = self.transformer(x)
return x
@@ -92,7 +92,7 @@ class DistillableEfficientViT(DistillMixin, EfficientViT):
v.load_state_dict(self.state_dict())
return v
def _attend(self, x, mask):
def _attend(self, x):
return self.transformer(x)
# knowledge distillation wrapper

54
vit_pytorch/recorder.py Normal file
View File

@@ -0,0 +1,54 @@
from functools import wraps
import torch
from torch import nn
from vit_pytorch.vit import Attention
def find_modules(nn_module, type):
return [module for module in nn_module.modules() if isinstance(module, type)]
class Recorder(nn.Module):
def __init__(self, vit):
super().__init__()
self.vit = vit
self.data = None
self.recordings = []
self.hooks = []
self.hook_registered = False
self.ejected = False
def _hook(self, _, input, output):
self.recordings.append(output.clone().detach())
def _register_hook(self):
modules = find_modules(self.vit.transformer, Attention)
for module in modules:
handle = module.attend.register_forward_hook(self._hook)
self.hooks.append(handle)
self.hook_registered = True
def eject(self):
self.ejected = True
for hook in self.hooks:
hook.remove()
self.hooks.clear()
return self.vit
def clear(self):
self.recordings.clear()
def record(self, attn):
recording = attn.clone().detach()
self.recordings.append(recording)
def forward(self, img):
assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
self.clear()
if not self.hook_registered:
self._register_hook()
pred = self.vit(img)
attns = torch.stack(self.recordings, dim = 1)
return pred, attns

View File

@@ -2,7 +2,7 @@ import math
import torch
from torch import nn
from vit_pytorch.vit_pytorch import Transformer
from vit_pytorch.vit import Transformer
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

115
vit_pytorch/vit.py Normal file
View File

@@ -0,0 +1,115 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)