From 8135d70e4e5ca292acaf7ee253324fa69a217d9c Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 29 Mar 2021 15:10:12 -0700
Subject: [PATCH] use hooks to retrieve attention maps for user without modifying ViT

---
 README.md               | 23 +++++++++-----
 setup.py                |  2 +-
 vit_pytorch/recorder.py | 68 +++++++++++++++++++++++++++--------------
 vit_pytorch/vit.py      | 23 +++++---------
 4 files changed, 70 insertions(+), 46 deletions(-)

diff --git a/README.md b/README.md
index b6b69bb..3f72cfc 100644
--- a/README.md
+++ b/README.md
@@ -219,10 +219,7 @@ If you would like to visualize the attention weights (post-softmax) for your res
 
 ```python
 import torch
-from vit_pytorch import ViT
-
-from vit_pytorch.recorder import Recorder # import the Recorder and instantiate
-rec = Recorder()
+from vit_pytorch.vit import ViT
 
 v = ViT(
     image_size = 256,
@@ -236,13 +233,25 @@ v = ViT(
     emb_dropout = 0.1
 )
 
-img = torch.randn(1, 3, 256, 256)
+# import Recorder and wrap the ViT
 
-preds = v(img, rec = rec) # pass in the recorder
+from vit_pytorch.recorder import Recorder
+v = Recorder(v)
+
+# forward pass now returns predictions and the attention maps
+
+img = torch.randn(1, 3, 256, 256)
+preds, attns = v(img)
 
 # there is one extra patch due to the CLS token
-rec.attn # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)
+attns # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)
 ```
+
+to clean up the class and the hooks once you have collected enough data
+
+```python
+v = v.eject() # wrapper is discarded and original ViT instance is returned
+```
 
 ## Research Ideas

diff --git a/setup.py b/setup.py
index 6edb860..b1b5a8b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.10.0',
+  version = '0.10.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/recorder.py b/vit_pytorch/recorder.py
index b663942..8760c30 100644
--- a/vit_pytorch/recorder.py
+++ b/vit_pytorch/recorder.py
@@ -1,35 +1,57 @@
 from functools import wraps
 import torch
+from torch import nn
+
+from vit_pytorch.vit import Attention
 
 def exists(val):
     return val is not None
 
-def record_wrapper(fn):
-    @wraps(fn)
-    def inner(model, img, **kwargs):
-        rec = kwargs.pop('rec', None)
-        if exists(rec):
-            rec.clear()
+def find_modules(nn_module, type):
+    return [module for module in nn_module.modules() if isinstance(module, type)]
 
-        out = fn(model, img, rec = rec, **kwargs)
-
-        if exists(rec):
-            rec.finalize()
-        return out
-    return inner
-
-class Recorder():
-    def __init__(self):
+class Recorder(nn.Module):
+    def __init__(self, vit):
         super().__init__()
-        self._layer_attns = []
-        self.attn = None
+        self.vit = vit
+
+        self.data = None
+        self.recordings = []
+        self.hooks = []
+        self.hook_registered = False
+        self.ejected = False
+
+    def _hook(self, _, input, output):
+        self.recordings.append(output.clone().detach())
+
+    def _register_hook(self):
+        modules = find_modules(self, Attention)
+        for module in modules:
+            handle = module.attend.register_forward_hook(self._hook)
+            self.hooks.append(handle)
+        self.hook_registered = True
+
+    def eject(self):
+        self.ejected = True
+        for hook in self.hooks:
+            hook.remove()
+        self.hooks.clear()
+        return self.vit
 
     def clear(self):
-        self._layer_attns.clear()
-        self.attn = None
-
-    def finalize(self):
-        self.attn = torch.stack(self._layer_attns, dim = 1)
+        self.recordings.clear()
 
     def record(self, attn):
-        self._layer_attns.append(attn.clone().detach())
+        recording = attn.clone().detach()
+        self.recordings.append(recording)
+
+    def forward(self, img):
+        assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
+        self.clear()
+
+        if not self.hook_registered:
+            self._register_hook()
+
+        pred = self.vit(img)
+        attns = torch.stack(self.recordings, dim = 1)
+        return pred, attns
diff --git a/vit_pytorch/vit.py b/vit_pytorch/vit.py
index 2c33ce8..90d1977 100644
--- a/vit_pytorch/vit.py
+++ b/vit_pytorch/vit.py
@@ -5,8 +5,6 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 
-from vit_pytorch.recorder import record_wrapper
-
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
@@ -37,6 +35,7 @@ class Attention(nn.Module):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.attend = nn.Softmax(dim = -1)
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
         self.to_out = nn.Sequential(
@@ -44,23 +43,18 @@ class Attention(nn.Module):
             nn.Dropout(dropout)
         ) if project_out else nn.Identity()
 
-    def forward(self, x, rec = None):
+    def forward(self, x):
         b, n, _, h = *x.shape, self.heads
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
-        attn = dots.softmax(dim=-1)
+        attn = self.attend(dots)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
-        out = self.to_out(out)
-
-        if rec is not None:
-            rec.record(attn)
-
-        return out
+        return self.to_out(out)
 
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
@@ -71,9 +65,9 @@ class Transformer(nn.Module):
                 PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
             ]))
-    def forward(self, x, **kwargs):
+    def forward(self, x):
         for attn, ff in self.layers:
-            x = attn(x, **kwargs) + x
+            x = attn(x) + x
             x = ff(x) + x
         return x
 
@@ -104,8 +98,7 @@ class ViT(nn.Module):
             nn.Linear(dim, num_classes)
         )
 
-    @record_wrapper
-    def forward(self, img, **kwargs):
+    def forward(self, img):
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape
 
@@ -114,7 +107,7 @@ class ViT(nn.Module):
         x += self.pos_embedding[:, :(n + 1)]
         x = self.dropout(x)
 
-        x = self.transformer(x, **kwargs)
+        x = self.transformer(x)
 
         x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
 
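---

The patch works because `vit.py` now routes the post-softmax weights through a dedicated `nn.Softmax` submodule (`self.attend`), giving `Recorder` something to attach a forward hook to; the old functional `dots.softmax(dim=-1)` call had no module to hook. Below is a minimal, self-contained sketch of that hook pattern under the same idea, independent of vit-pytorch; `TinyAttention`, `save_attention`, and `recordings` are illustrative names and not part of this commit.

```python
import torch
from torch import nn

class TinyAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** -0.5
        self.attend = nn.Softmax(dim = -1)  # dedicated softmax module: the hook target
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

    def forward(self, x):
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        dots = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = self.attend(dots)            # forward hook fires here with the attention map
        return torch.einsum('b i j, b j d -> b i d', attn, v)

recordings = []

def save_attention(module, input, output):
    # forward hooks receive (module, input, output); output is the post-softmax map
    recordings.append(output.clone().detach())

layer = TinyAttention(dim = 32)
handle = layer.attend.register_forward_hook(save_attention)

x = torch.randn(1, 10, 32)
layer(x)
print(recordings[0].shape)  # torch.Size([1, 10, 10])

handle.remove()  # the same cleanup Recorder.eject() performs for each registered hook
```

Removing the handle at the end mirrors what `eject()` does before handing back the unwrapped ViT, so the model carries no leftover hooks once recording is done.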