mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-29 23:52:27 +00:00
91 lines
2.4 KiB
Python
91 lines
2.4 KiB
Python
import torch
|
|
from torch import nn
|
|
|
|
def exists(val):
    """Return True when `val` is anything other than None (falsy values such as 0 or '' still count)."""
    if val is None:
        return False
    return True
|
|
|
|
def identity(t):
    """Return `t` itself, untouched - a no-op used where a transform callable is expected."""
    return t
|
|
|
|
def clone_and_detach(t):
    """Return a copy of `t` that is detached from the autograd graph.

    `t` must provide `.clone()` and `.detach()` (e.g. a torch.Tensor).
    """
    # copy first, then detach the copy, so the result never aliases `t`
    copied = t.clone()
    return copied.detach()
|
|
|
|
def apply_tuple_or_single(fn, val):
    """Apply `fn` to `val`, mapping over the elements when `val` is a tuple.

    Returns a new tuple of results for tuple input, otherwise `fn(val)`.
    Only `tuple` instances are unpacked - lists and other containers are
    handed to `fn` as a single value.
    """
    if not isinstance(val, tuple):
        return fn(val)

    mapped = map(fn, val)
    return tuple(mapped)
|
|
|
|
class Extractor(nn.Module):
    """Wrap a model and capture the activations of one of its submodules.

    A forward hook is registered lazily (on the first call to `forward`) on
    either the module passed as `layer` or, when `layer` is None, on the
    attribute of `vit` named `layer_name`. Every forward pass stores what the
    hook saw in `self.latents` and returns it together with the prediction.

    Args:
        vit: the wrapped module; it is invoked as `vit(img)`.
        device: if given, captured latents are moved to this device;
            otherwise the device of the input `img` is used.
        layer: module to hook. When None, `getattr(vit, layer_name)` is hooked.
        layer_name: attribute name looked up on `vit` when `layer` is None.
        layer_save_input: capture the hooked layer's inputs instead of its output.
        return_embeddings_only: make `forward` return only the latents.
        detach: clone and detach captured tensors; when False they are kept as-is.
    """

    def __init__(
        self,
        vit,
        device = None,
        layer = None,
        layer_name = 'transformer',
        layer_save_input = False,
        return_embeddings_only = False,
        detach = True
    ):
        super().__init__()
        self.vit = vit

        # runtime state
        self.data = None
        self.latents = None          # last captured activations, reset on every forward
        self.hooks = []              # handles of registered hooks, removed by `eject`
        self.hook_registered = False
        self.ejected = False
        self.device = device

        self.layer = layer
        self.layer_name = layer_name
        self.layer_save_input = layer_save_input # whether to save input or output of layer
        self.return_embeddings_only = return_embeddings_only

        self.detach_fn = clone_and_detach if detach else identity

    def _hook(self, _, inputs, output):
        """Forward hook: store the layer's inputs or output (per `layer_save_input`) in `self.latents`."""
        layer_output = inputs if self.layer_save_input else output
        # tuples are processed element-wise, anything else as a single value
        self.latents = apply_tuple_or_single(self.detach_fn, layer_output)

    def _register_hook(self):
        """Attach `_hook` to the target layer and remember the handle.

        Raises:
            AssertionError: if no `layer` was given and `vit` has no attribute `layer_name`.
        """
        if not exists(self.layer):
            assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
            layer = getattr(self.vit, self.layer_name)
        else:
            layer = self.layer

        handle = layer.register_forward_hook(self._hook)
        self.hooks.append(handle)
        self.hook_registered = True

    def eject(self):
        """Remove all registered hooks, mark the extractor unusable and return the wrapped `vit`."""
        self.ejected = True
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        return self.vit

    def clear(self):
        """Drop the latents captured by the previous forward pass."""
        del self.latents
        self.latents = None

    def forward(
        self,
        img,
        return_embeddings_only = False
    ):
        """Run `vit(img)` and return the activations captured by the hook.

        Args:
            img: input passed straight to the wrapped module; its `.device`
                is used as the target device when `self.device` is None.
            return_embeddings_only: return only the latents for this call
                (the constructor flag of the same name has the same effect).

        Returns:
            `latents` when either `return_embeddings_only` flag is set,
            otherwise the pair `(pred, latents)`. `latents` is a tuple when
            the hook captured a tuple.

        Raises:
            AssertionError: if the extractor has been ejected.
            RuntimeError: if nothing was captured during the forward pass.
        """
        assert not self.ejected, 'extractor has been ejected, cannot be used anymore'
        self.clear()
        if not self.hook_registered:
            self._register_hook()

        pred = self.vit(img)

        # The hook only fires when the hooked layer is actually invoked. Without this
        # check a missing capture surfaced as an opaque
        # "'NoneType' object has no attribute 'to'" AttributeError below.
        if not exists(self.latents):
            raise RuntimeError('no latents were captured - the hooked layer was not called (or produced None) during the forward pass')

        target_device = self.device if exists(self.device) else img.device
        latents = apply_tuple_or_single(lambda t: t.to(target_device), self.latents)

        if return_embeddings_only or self.return_embeddings_only:
            return latents

        return pred, latents
|