diff --git a/setup.py b/setup.py
index 256a0be..426a69c 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.35.7',
+  version = '0.35.8',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/extractor.py b/vit_pytorch/extractor.py
index fc30ed1..59bc300 100644
--- a/vit_pytorch/extractor.py
+++ b/vit_pytorch/extractor.py
@@ -4,6 +4,12 @@ from torch import nn
 def exists(val):
     return val is not None
 
+def identity(t):
+    return t
+
+def clone_and_detach(t):
+    return t.clone().detach()
+
 def apply_tuple_or_single(fn, val):
     if isinstance(val, tuple):
         return tuple(map(fn, val))
@@ -17,7 +23,8 @@ class Extractor(nn.Module):
         layer = None,
         layer_name = 'transformer',
         layer_save_input = False,
-        return_embeddings_only = False
+        return_embeddings_only = False,
+        detach = True
     ):
         super().__init__()
         self.vit = vit
@@ -34,9 +41,11 @@ class Extractor(nn.Module):
         self.layer_save_input = layer_save_input # whether to save input or output of layer
         self.return_embeddings_only = return_embeddings_only
 
+        self.detach_fn = clone_and_detach if detach else identity
+
     def _hook(self, _, inputs, output):
         layer_output = inputs if self.layer_save_input else output
-        self.latents = apply_tuple_or_single(lambda t: t.clone().detach(), layer_output)
+        self.latents = apply_tuple_or_single(self.detach_fn, layer_output)
 
     def _register_hook(self):
         if not exists(self.layer):
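
For reference, a minimal usage sketch of the new detach flag, following the Extractor wrapper pattern from the vit_pytorch README; the ViT hyperparameters below are purely illustrative, not part of this change. Passing detach = False keeps the saved latents attached to the autograd graph instead of returning cloned, detached copies.

    import torch
    from vit_pytorch import ViT
    from vit_pytorch.extractor import Extractor

    # illustrative hyperparameters, not prescribed by this diff
    v = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048
    )

    # detach = True (the default) preserves the old clone-and-detach behavior;
    # detach = False leaves the latents on the autograd graph
    v = Extractor(v, detach = False)

    img = torch.randn(1, 3, 256, 256)
    logits, embeddings = v(img)
    # with detach = False, embeddings carry a grad_fn, so a loss computed
    # on them can backpropagate into the wrapped ViT's parameters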