offer way for extractor to return latents without detaching them

This commit is contained in:
Phil Wang
2022-07-16 16:22:40 -07:00
parent 2fa2b62def
commit f86e052c05
2 changed files with 12 additions and 3 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.35.7',
version = '0.35.8',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',

View File

@@ -4,6 +4,12 @@ from torch import nn
def exists(val):
return val is not None
def identity(t):
return t
def clone_and_detach(t):
return t.clone().detach()
def apply_tuple_or_single(fn, val):
if isinstance(val, tuple):
return tuple(map(fn, val))
@@ -17,7 +23,8 @@ class Extractor(nn.Module):
layer = None,
layer_name = 'transformer',
layer_save_input = False,
return_embeddings_only = False
return_embeddings_only = False,
detach = True
):
super().__init__()
self.vit = vit
@@ -34,9 +41,11 @@ class Extractor(nn.Module):
self.layer_save_input = layer_save_input # whether to save input or output of layer
self.return_embeddings_only = return_embeddings_only
self.detach_fn = clone_and_detach if detach else identity
def _hook(self, _, inputs, output):
layer_output = inputs if self.layer_save_input else output
self.latents = apply_tuple_or_single(lambda t: t.clone().detach(), layer_output)
self.latents = apply_tuple_or_single(self.detach_fn, layer_output)
def _register_hook(self):
if not exists(self.layer):