Mirror of https://github.com/lucidrains/vit-pytorch.git (synced 2025-12-30 08:02:29 +00:00)

Commit: use hooks to retrieve attention maps for user without modifying ViT
README.md (23 changed lines)
@@ -219,10 +219,7 @@ If you would like to visualize the attention weights (post-softmax) for your res
 
 ```python
 import torch
-from vit_pytorch import ViT
-
-from vit_pytorch.recorder import Recorder # import the Recorder and instantiate
-rec = Recorder()
+from vit_pytorch.vit import ViT
 
 v = ViT(
     image_size = 256,
@@ -236,13 +233,25 @@ v = ViT(
     emb_dropout = 0.1
 )
 
-img = torch.randn(1, 3, 256, 256)
+# import Recorder and wrap the ViT
 
-preds = v(img, rec = rec) # pass in the recorder
+from vit_pytorch.recorder import Recorder
+v = Recorder(v)
+
+# forward pass now returns predictions and the attention maps
+
+img = torch.randn(1, 3, 256, 256)
+preds, attns = v(img)
 
 # there is one extra patch due to the CLS token
 
-rec.attn # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)
+attns # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)
 ```
 
+to cleanup the class and the hooks once you have collected enough data
+
+```python
+v = v.eject() # wrapper is discarded and original ViT instance is returned
+```
+
 ## Research Ideas
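Since the README example stops at the raw `attns` tensor, here is a minimal follow-on sketch (not part of the committed README) of turning it into a per-patch heatmap. It assumes the usual README configuration of 256x256 images with 32x32 patches, which is what makes the 65 = 1 CLS token + 64 patches in the shape comment work out; the variable names are illustrative only.

```python
# attns: (batch, layers, heads, 65, 65); index 0 on the last two axes is the CLS token
last_layer = attns[:, -1]                  # (1, 16, 65, 65)  attention maps of the final layer
cls_to_patches = last_layer[:, :, 0, 1:]   # (1, 16, 64)      how much the CLS token attends to each patch
heatmap = cls_to_patches.mean(dim = 1).reshape(-1, 8, 8)  # (1, 8, 8) grid, since 256 / 32 = 8 patches per side
```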
setup.py (2 changed lines)

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.10.0',
+  version = '0.10.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
vit_pytorch/recorder.py

@@ -1,35 +1,57 @@
 from functools import wraps
 import torch
+from torch import nn
+
+from vit_pytorch.vit import Attention
 
 def exists(val):
     return val is not None
 
-def record_wrapper(fn):
-    @wraps(fn)
-    def inner(model, img, **kwargs):
-        rec = kwargs.pop('rec', None)
-
-        if exists(rec):
-            rec.clear()
-
-        out = fn(model, img, rec = rec, **kwargs)
-
-        if exists(rec):
-            rec.finalize()
-        return out
-    return inner
+def find_modules(nn_module, type):
+    return [module for module in nn_module.modules() if isinstance(module, type)]
 
-class Recorder():
-    def __init__(self):
-        self._layer_attns = []
-        self.attn = None
+class Recorder(nn.Module):
+    def __init__(self, vit):
+        super().__init__()
+        self.vit = vit
+
+        self.data = None
+        self.recordings = []
+        self.hooks = []
+        self.hook_registered = False
+        self.ejected = False
+
+    def _hook(self, _, input, output):
+        self.recordings.append(output.clone().detach())
+
+    def _register_hook(self):
+        modules = find_modules(self, Attention)
+        for module in modules:
+            handle = module.attend.register_forward_hook(self._hook)
+            self.hooks.append(handle)
+        self.hook_registered = True
+
+    def eject(self):
+        self.ejected = True
+        for hook in self.hooks:
+            hook.remove()
+        self.hooks.clear()
+        return self.vit
 
     def clear(self):
-        self._layer_attns.clear()
-        self.attn = None
-
-    def finalize(self):
-        self.attn = torch.stack(self._layer_attns, dim = 1)
+        self.recordings.clear()
 
     def record(self, attn):
-        self._layer_attns.append(attn.clone().detach())
+        recording = attn.clone().detach()
+        self.recordings.append(recording)
+
+    def forward(self, img):
+        assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
+        self.clear()
+
+        if not self.hook_registered:
+            self._register_hook()
+
+        pred = self.vit(img)
+        attns = torch.stack(self.recordings, dim = 1)
+        return pred, attns
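As background for the class above (not part of the commit), here is a minimal, self-contained sketch of the forward-hook lifecycle that `_register_hook`, `_hook`, and `eject` build on: find submodules of a given type, attach a hook that receives `(module, input, output)` after each forward pass, and later remove the returned handles. The toy model and names are illustrative only.

```python
import torch
from torch import nn

# stand-in for the ViT: we want to capture the Softmax outputs without editing the model
model = nn.Sequential(nn.Linear(4, 4), nn.Softmax(dim = -1))

recordings = []

def hook(module, input, output):
    # a forward hook is called after the module's forward pass, with its input and output
    recordings.append(output.detach())

# register on every Softmax submodule, keeping the handles (cf. find_modules + _register_hook)
handles = [m.register_forward_hook(hook) for m in model.modules() if isinstance(m, nn.Softmax)]

model(torch.randn(2, 4))
print(recordings[0].shape)  # torch.Size([2, 4])

# cf. eject(): removing the handles detaches the hooks and leaves the model untouched
for handle in handles:
    handle.remove()
```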
vit_pytorch/vit.py

@@ -5,8 +5,6 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 
-from vit_pytorch.recorder import record_wrapper
-
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
@@ -37,6 +35,7 @@ class Attention(nn.Module):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.attend = nn.Softmax(dim = -1)
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
         self.to_out = nn.Sequential(
@@ -44,23 +43,18 @@ class Attention(nn.Module):
             nn.Dropout(dropout)
         ) if project_out else nn.Identity()
 
-    def forward(self, x, rec = None):
+    def forward(self, x):
         b, n, _, h = *x.shape, self.heads
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
-        attn = dots.softmax(dim=-1)
+        attn = self.attend(dots)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
-        out = self.to_out(out)
-
-        if rec is not None:
-            rec.record(attn)
-
-        return out
+        return self.to_out(out)
 
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
@@ -71,9 +65,9 @@ class Transformer(nn.Module):
                 PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
             ]))
-    def forward(self, x, **kwargs):
+    def forward(self, x):
         for attn, ff in self.layers:
-            x = attn(x, **kwargs) + x
+            x = attn(x) + x
             x = ff(x) + x
         return x
 
@@ -104,8 +98,7 @@ class ViT(nn.Module):
             nn.Linear(dim, num_classes)
         )
 
-    @record_wrapper
-    def forward(self, img, **kwargs):
+    def forward(self, img):
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape
 
@@ -114,7 +107,7 @@ class ViT(nn.Module):
         x += self.pos_embedding[:, :(n + 1)]
         x = self.dropout(x)
 
-        x = self.transformer(x, **kwargs)
+        x = self.transformer(x)
 
         x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
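One design note on the `self.attend = nn.Softmax(dim = -1)` change, with a hedged illustration (the tensor and shapes below are made up for the example): forward hooks can only be registered on `nn.Module` instances, so the post-softmax attention has to come out of a submodule rather than the functional `dots.softmax(dim=-1)` call for the Recorder's hook to have something to attach to.

```python
import torch
from torch import nn

attend = nn.Softmax(dim = -1)            # a module: hookable
captured = []
attend.register_forward_hook(lambda module, input, output: captured.append(output.detach()))

dots = torch.randn(1, 16, 65, 65)        # (batch, heads, patches, patches), illustrative shape
attn = attend(dots)                      # the hook fires here and records the attention map
assert torch.allclose(captured[0], attn)

# dots.softmax(dim = -1) would give the same values, but there would be no module to hook
```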