Compare commits

...

1 commit

2 changed files with 9 additions and 3 deletions

setup.py

```diff
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.35.3',
+  version = '0.35.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
```

vit_pytorch/extractor.py

```diff
@@ -14,6 +14,7 @@ class Extractor(nn.Module):
         self,
         vit,
         device = None,
+        layer = None,
         layer_name = 'transformer',
         layer_save_input = False,
         return_embeddings_only = False
@@ -28,6 +29,7 @@ class Extractor(nn.Module):
         self.ejected = False
         self.device = device
+        self.layer = layer
         self.layer_name = layer_name
         self.layer_save_input = layer_save_input # whether to save input or output of layer
         self.return_embeddings_only = return_embeddings_only
@@ -37,8 +39,12 @@ class Extractor(nn.Module):
         self.latents = apply_tuple_or_single(lambda t: t.clone().detach(), layer_output)
 
     def _register_hook(self):
-        assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
-        layer = getattr(self.vit, self.layer_name)
+        if not exists(self.layer):
+            assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
+            layer = getattr(self.vit, self.layer_name)
+        else:
+            layer = self.layer
+
         handle = layer.register_forward_hook(self._hook)
         self.hooks.append(handle)
         self.hook_registered = True
```
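
This commit lets a module be handed to `Extractor` directly via the new `layer` argument instead of being looked up on the wrapped ViT by `layer_name`. Below is a minimal usage sketch, assuming the `ViT` constructor and `Extractor` wrapper shown in this repo's README; the submodule path `v.transformer.layers[-1][1]` is only an illustrative assumption about the model's internal layout and may differ between versions.

```python
import torch

from vit_pytorch.vit import ViT
from vit_pytorch.extractor import Extractor

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# previous behavior (still the default): the hooked module is found by name,
# i.e. getattr(vit, 'transformer')
by_name = Extractor(v)

# new in this commit: pass the module itself via `layer`, so any submodule can
# be hooked. the path below (feedforward sublayer of the last transformer
# block) is an assumption about this ViT's layout, used purely as an example.
by_module = Extractor(v, layer = v.transformer.layers[-1][1])

img = torch.randn(1, 3, 256, 256)

# the wrapped model returns the usual predictions plus the saved embeddings
logits, embeddings = by_module(img)
```

Since the name-based lookup only runs when `layer` is not given, existing callers that rely on `layer_name` are unaffected.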