diff --git a/setup.py b/setup.py
index 2af2fde..3d4324f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.35.3',
+  version = '0.35.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/extractor.py b/vit_pytorch/extractor.py
index aff9704..fc30ed1 100644
--- a/vit_pytorch/extractor.py
+++ b/vit_pytorch/extractor.py
@@ -14,6 +14,7 @@ class Extractor(nn.Module):
         self,
         vit,
         device = None,
+        layer = None,
         layer_name = 'transformer',
         layer_save_input = False,
         return_embeddings_only = False
@@ -28,6 +29,7 @@ class Extractor(nn.Module):
 
         self.ejected = False
         self.device = device
+        self.layer = layer
         self.layer_name = layer_name
         self.layer_save_input = layer_save_input # whether to save input or output of layer
         self.return_embeddings_only = return_embeddings_only
@@ -37,8 +39,12 @@ class Extractor(nn.Module):
         self.latents = apply_tuple_or_single(lambda t: t.clone().detach(), layer_output)
 
     def _register_hook(self):
-        assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
-        layer = getattr(self.vit, self.layer_name)
+        if not exists(self.layer):
+            assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
+            layer = getattr(self.vit, self.layer_name)
+        else:
+            layer = self.layer
+
         handle = layer.register_forward_hook(self._hook)
         self.hooks.append(handle)
         self.hook_registered = True