# Source: vit-pytorch/vit_pytorch/extractor.py
# (page-scrape header removed: "Files / 91 lines / 2.4 KiB / Python / Raw Normal View History",
#  capture timestamp 2021-12-21 11:11:39 -08:00)
import torch
from torch import nn
def exists(val):
    """Return True when *val* is anything other than None."""
    return not (val is None)
def identity(t):
    """Pass *t* through unchanged (no-op counterpart to clone_and_detach)."""
    return t
def clone_and_detach(t):
    """Return a copy of tensor *t* that is severed from the autograd graph."""
    copied = t.clone()
    return copied.detach()
def apply_tuple_or_single(fn, val):
    """Apply *fn* to *val*, mapping element-wise when *val* is a tuple."""
    if not isinstance(val, tuple):
        return fn(val)
    return tuple(fn(el) for el in val)
2021-12-21 11:11:39 -08:00
class Extractor(nn.Module):
    """Wrap a vision transformer and capture an intermediate layer's embeddings.

    A forward hook is registered on either an explicitly supplied ``layer``
    or, by default, on ``getattr(vit, layer_name)``. Each forward pass through
    the wrapped model records that layer's output (or input, when
    ``layer_save_input`` is True) so it can be returned alongside — or instead
    of — the model's prediction.

    Args:
        vit: the vision transformer module to wrap.
        device: device to move the captured latents to; defaults to the
            input image's device when None.
        layer: the exact module to hook; when None, ``layer_name`` is
            looked up on ``vit`` instead.
        layer_name: attribute name on ``vit`` to hook when ``layer`` is None.
        layer_save_input: save the hooked layer's input rather than its output.
        return_embeddings_only: always return only the latents from forward.
        detach: clone-and-detach the captured latents from the autograd
            graph (default) rather than keeping them attached.
    """

    def __init__(
        self,
        vit,
        device = None,
        layer = None,
        layer_name = 'transformer',
        layer_save_input = False,
        return_embeddings_only = False,
        detach = True
    ):
        super().__init__()
        self.vit = vit

        self.data = None    # reserved; not used by any method in this class
        self.latents = None # most recently captured embeddings, set by the hook

        self.hooks = []
        self.hook_registered = False
        self.ejected = False
        self.device = device

        self.layer = layer
        self.layer_name = layer_name
        self.layer_save_input = layer_save_input # whether to save input or output of layer
        self.return_embeddings_only = return_embeddings_only

        self.detach_fn = clone_and_detach if detach else identity

    def _hook(self, _, inputs, output):
        # forward-hook callback: stash the layer's input or output
        # (detached by default) for retrieval after the forward pass
        layer_output = inputs if self.layer_save_input else output
        self.latents = apply_tuple_or_single(self.detach_fn, layer_output)

    def _register_hook(self):
        # resolve the target layer — explicit `layer` wins over `layer_name` lookup
        if not exists(self.layer):
            assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
            layer = getattr(self.vit, self.layer_name)
        else:
            layer = self.layer

        handle = layer.register_forward_hook(self._hook)
        self.hooks.append(handle)
        self.hook_registered = True

    def eject(self):
        """Remove all hooks and return the unwrapped vit; the extractor is dead afterwards."""
        self.ejected = True
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        return self.vit

    def clear(self):
        """Drop any latents captured on a previous forward pass."""
        del self.latents
        self.latents = None

    def forward(
        self,
        img,
        return_embeddings_only = False
    ):
        """Run the wrapped vit on *img*.

        Returns the latents alone when embeddings-only mode is requested
        (either here or at construction), otherwise ``(pred, latents)``.
        """
        assert not self.ejected, 'extractor has been ejected, cannot be used anymore'
        self.clear()
        if not self.hook_registered:
            self._register_hook()

        pred = self.vit(img)

        # fail loudly if the hooked layer was never executed during this forward
        # pass (e.g. wrong layer / layer_name), instead of crashing below with an
        # opaque AttributeError on None
        assert exists(self.latents), 'the hooked layer did not run during the forward pass, so no embeddings were captured'

        target_device = self.device if exists(self.device) else img.device
        latents = apply_tuple_or_single(lambda t: t.to(target_device), self.latents)

        if return_embeddings_only or self.return_embeddings_only:
            return latents

        return pred, latents