make extractor flexible for layers that output multiple tensors, show CrossViT example

Phil Wang
2022-06-19 08:11:41 -07:00
parent b3e90a2652
commit 4e62e5f05e
3 changed files with 51 additions and 4 deletions

README.md

@@ -1255,6 +1255,47 @@ logits, embeddings = v(img)
embeddings # (1, 65, 1024) - (batch x patches x model dim)
```
Or say for `CrossViT`, which has a multi-scale encoder that outputs two sets of embeddings, one each for the 'small' and 'large' scales

```python
import torch
from vit_pytorch.cross_vit import CrossViT

v = CrossViT(
    image_size = 256,
    num_classes = 1000,
    depth = 4,
    sm_dim = 192,
    sm_patch_size = 16,
    sm_enc_depth = 2,
    sm_enc_heads = 8,
    sm_enc_mlp_dim = 2048,
    lg_dim = 384,
    lg_patch_size = 64,
    lg_enc_depth = 3,
    lg_enc_heads = 8,
    lg_enc_mlp_dim = 2048,
    cross_attn_depth = 2,
    cross_attn_heads = 8,
    dropout = 0.1,
    emb_dropout = 0.1
)

# wrap the CrossViT

from vit_pytorch.extractor import Extractor
v = Extractor(v, layer_name = 'multi_scale_encoder') # take the embeddings coming from the output of the multi-scale encoder

# forward pass now returns predictions and the embeddings from both scales

img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)

# there is one extra token in each sequence due to the CLS token

embeddings # ((1, 257, 192), (1, 17, 384)) - (batch x patches x dimension) <- small and large scales respectively
```
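The `Extractor` can also return just the embeddings and skip the logits, either by constructing it with `return_embeddings_only = True` or by passing that flag on the forward call (both appear in the `extractor.py` changes below); continuing the example above:

```python
# return only the multi-scale embeddings, without the classification logits
sm_tokens, lg_tokens = v(img, return_embeddings_only = True)

sm_tokens.shape # (1, 257, 192)
lg_tokens.shape # (1, 17, 384)
```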
## Research Ideas
### Efficient Attention

setup.py

@@ -3,9 +3,10 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.35.2',
+  version = '0.35.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
+  long_description_content_type = 'text/markdown',
   author = 'Phil Wang',
   author_email = 'lucidrains@gmail.com',
   url = 'https://github.com/lucidrains/vit-pytorch',

vit_pytorch/extractor.py

@@ -4,6 +4,11 @@ from torch import nn
 def exists(val):
     return val is not None

+def apply_tuple_or_single(fn, val):
+    if isinstance(val, tuple):
+        return tuple(map(fn, val))
+    return fn(val)
+
 class Extractor(nn.Module):
     def __init__(
         self,
@@ -28,8 +33,8 @@ class Extractor(nn.Module):
         self.return_embeddings_only = return_embeddings_only

     def _hook(self, _, inputs, output):
-        tensor_to_save = inputs if self.layer_save_input else output
-        self.latents = tensor_to_save.clone().detach()
+        layer_output = inputs if self.layer_save_input else output
+        self.latents = apply_tuple_or_single(lambda t: t.clone().detach(), layer_output)

     def _register_hook(self):
         assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
@@ -62,7 +67,7 @@ class Extractor(nn.Module):
         pred = self.vit(img)

         target_device = self.device if exists(self.device) else img.device
-        latents = self.latents.to(target_device)
+        latents = apply_tuple_or_single(lambda t: t.to(target_device), self.latents)

         if return_embeddings_only or self.return_embeddings_only:
             return latents