use hooks to retrieve attention maps for user without modifying ViT

Phil Wang
2021-03-29 15:10:12 -07:00
parent 3067155cea
commit 8135d70e4e
4 changed files with 70 additions and 46 deletions

README.md

@@ -219,10 +219,7 @@ If you would like to visualize the attention weights (post-softmax) for your res
 ```python
 import torch
-from vit_pytorch import ViT
-from vit_pytorch.recorder import Recorder # import the Recorder and instantiate
-rec = Recorder()
+from vit_pytorch.vit import ViT
 v = ViT(
     image_size = 256,
@@ -236,13 +233,25 @@ v = ViT(
     emb_dropout = 0.1
 )
-img = torch.randn(1, 3, 256, 256)
+# import Recorder and wrap the ViT
-preds = v(img, rec = rec) # pass in the recorder
+from vit_pytorch.recorder import Recorder
+v = Recorder(v)
+# forward pass now returns predictions and the attention maps
+img = torch.randn(1, 3, 256, 256)
+preds, attns = v(img)
 # there is one extra patch due to the CLS token
-rec.attn # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)
+attns # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)
 ```
+to cleanup the class and the hooks once you have collected enough data
+```python
+v = v.eject() # wrapper is discarded and original ViT instance is returned
+```
 ## Research Ideas
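The new README snippet returns `attns` with shape `(1, 6, 16, 65, 65)`. A rough sketch of how that tensor might be reduced to a per-patch heatmap follows; it is not part of this commit, and it assumes the full README configuration (patch_size = 32 and the other hyperparameters from the README example), which is what yields the 64 patches plus CLS token behind the 65-token dimension.

```python
# Sketch only (not from this commit): reduce recorded attention maps to an
# 8 x 8 CLS-attention heatmap, assuming the README's patch_size = 32 config.
import torch
from vit_pytorch.vit import ViT
from vit_pytorch.recorder import Recorder

v = Recorder(ViT(
    image_size = 256, patch_size = 32, num_classes = 1000, dim = 1024,
    depth = 6, heads = 16, mlp_dim = 2048, dropout = 0.1, emb_dropout = 0.1
))

img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)                      # attns: (1, 6, 16, 65, 65)

last_layer = attns[:, -1]                  # (1, 16, 65, 65) - final layer only
avg_heads = last_layer.mean(dim = 1)       # (1, 65, 65)     - average over heads
cls_to_patch = avg_heads[:, 0, 1:]         # (1, 64)         - CLS row, drop the CLS column
heatmap = cls_to_patch.reshape(1, 8, 8)    # 8 x 8 grid of 32px patches on a 256px image

v = v.eject()                              # hand back the plain ViT when finished
```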

setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.10.0',
+  version = '0.10.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

vit_pytorch/recorder.py

@@ -1,35 +1,57 @@
-from functools import wraps
 import torch
+from torch import nn
+from vit_pytorch.vit import Attention
-def exists(val):
-    return val is not None
-def record_wrapper(fn):
-    @wraps(fn)
-    def inner(model, img, **kwargs):
-        rec = kwargs.pop('rec', None)
-        if exists(rec):
-            rec.clear()
+def find_modules(nn_module, type):
+    return [module for module in nn_module.modules() if isinstance(module, type)]
-        out = fn(model, img, rec = rec, **kwargs)
-        if exists(rec):
-            rec.finalize()
-        return out
-    return inner
-class Recorder():
-    def __init__(self):
+class Recorder(nn.Module):
+    def __init__(self, vit):
+        super().__init__()
-        self._layer_attns = []
-        self.attn = None
+        self.vit = vit
+        self.data = None
+        self.recordings = []
+        self.hooks = []
+        self.hook_registered = False
+        self.ejected = False
+    def _hook(self, _, input, output):
+        self.recordings.append(output.clone().detach())
+    def _register_hook(self):
+        modules = find_modules(self, Attention)
+        for module in modules:
+            handle = module.attend.register_forward_hook(self._hook)
+            self.hooks.append(handle)
+        self.hook_registered = True
+    def eject(self):
+        self.ejected = True
+        for hook in self.hooks:
+            hook.remove()
+        self.hooks.clear()
+        return self.vit
     def clear(self):
-        self._layer_attns.clear()
-        self.attn = None
-    def finalize(self):
-        self.attn = torch.stack(self._layer_attns, dim = 1)
+        self.recordings.clear()
     def record(self, attn):
-        self._layer_attns.append(attn.clone().detach())
+        recording = attn.clone().detach()
+        self.recordings.append(recording)
+    def forward(self, img):
+        assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
+        self.clear()
+        if not self.hook_registered:
+            self._register_hook()
+        pred = self.vit(img)
+        attns = torch.stack(self.recordings, dim = 1)
+        return pred, attns
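The rewritten recorder relies on PyTorch's forward hooks: `register_forward_hook` lets outside code observe a submodule's output without modifying the module itself, which is what `_register_hook` above does for every `Attention.attend` softmax it finds. A minimal, self-contained illustration of that mechanism (not from this repository) is below.

```python
# Minimal forward-hook sketch (illustrative only, not part of this commit).
import torch
from torch import nn

outputs = []

def save_output(module, inputs, output):
    # runs right after `module.forward`; stash a detached copy of its output
    outputs.append(output.detach().clone())

model = nn.Sequential(nn.Linear(4, 8), nn.Softmax(dim = -1))
handle = model[1].register_forward_hook(save_output)   # hook the softmax module

model(torch.randn(2, 4))
print(outputs[0].shape)    # torch.Size([2, 8]) - the post-softmax activations

handle.remove()            # clean up, analogous to Recorder.eject()
```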

vit_pytorch/vit.py

@@ -5,8 +5,6 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
-from vit_pytorch.recorder import record_wrapper
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
@@ -37,6 +35,7 @@ class Attention(nn.Module):
         self.heads = heads
         self.scale = dim_head ** -0.5
+        self.attend = nn.Softmax(dim = -1)
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
@@ -44,23 +43,18 @@ class Attention(nn.Module):
             nn.Dropout(dropout)
         ) if project_out else nn.Identity()
-    def forward(self, x, rec = None):
+    def forward(self, x):
         b, n, _, h = *x.shape, self.heads
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
-        attn = dots.softmax(dim=-1)
+        attn = self.attend(dots)
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
-        out = self.to_out(out)
-        if rec is not None:
-            rec.record(attn)
-        return out
+        return self.to_out(out)
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
@@ -71,9 +65,9 @@ class Transformer(nn.Module):
                 PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
             ]))
-    def forward(self, x, **kwargs):
+    def forward(self, x):
         for attn, ff in self.layers:
-            x = attn(x, **kwargs) + x
+            x = attn(x) + x
             x = ff(x) + x
         return x
@@ -104,8 +98,7 @@ class ViT(nn.Module):
             nn.Linear(dim, num_classes)
         )
-    @record_wrapper
-    def forward(self, img, **kwargs):
+    def forward(self, img):
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape
@@ -114,7 +107,7 @@ class ViT(nn.Module):
         x += self.pos_embedding[:, :(n + 1)]
         x = self.dropout(x)
-        x = self.transformer(x, **kwargs)
+        x = self.transformer(x)
         x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
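The key enabler in `vit.py` is that the softmax now lives in a named submodule, `self.attend = nn.Softmax(dim = -1)`, instead of the functional `dots.softmax(dim=-1)` call: forward hooks can only be registered on `nn.Module` instances, so only after this change can the recorder discover and hook the post-softmax attention. A quick check of that (a sketch under the README example's assumed hyperparameters, not part of the commit):

```python
# Sketch: the post-softmax op is now a discoverable (and hookable) submodule.
from torch import nn
from vit_pytorch.vit import ViT, Attention

v = ViT(
    image_size = 256, patch_size = 32, num_classes = 1000, dim = 1024,
    depth = 6, heads = 16, mlp_dim = 2048, dropout = 0.1, emb_dropout = 0.1
)

attn_layers = [m for m in v.modules() if isinstance(m, Attention)]
print(len(attn_layers))                               # 6 - one per transformer layer
print(isinstance(attn_layers[0].attend, nn.Softmax))  # True - this is what gets hooked
```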