diff --git a/README.md b/README.md
index 4afca12..023f162 100644
--- a/README.md
+++ b/README.md
@@ -1255,6 +1255,47 @@ logits, embeddings = v(img)
 embeddings # (1, 65, 1024) - (batch x patches x model dim)
 ```
 
+Or say for `CrossViT`, which has a multi-scale encoder that outputs two sets of embeddings, one for the 'small' scale and one for the 'large' scale
+
+```python
+import torch
+from vit_pytorch.cross_vit import CrossViT
+
+v = CrossViT(
+    image_size = 256,
+    num_classes = 1000,
+    depth = 4,
+    sm_dim = 192,
+    sm_patch_size = 16,
+    sm_enc_depth = 2,
+    sm_enc_heads = 8,
+    sm_enc_mlp_dim = 2048,
+    lg_dim = 384,
+    lg_patch_size = 64,
+    lg_enc_depth = 3,
+    lg_enc_heads = 8,
+    lg_enc_mlp_dim = 2048,
+    cross_attn_depth = 2,
+    cross_attn_heads = 8,
+    dropout = 0.1,
+    emb_dropout = 0.1
+)
+
+# wrap the CrossViT
+
+from vit_pytorch.extractor import Extractor
+v = Extractor(v, layer_name = 'multi_scale_encoder') # take the embeddings coming from the output of the multi-scale encoder
+
+# forward pass now returns predictions and the embeddings from the multi-scale encoder
+
+img = torch.randn(1, 3, 256, 256)
+logits, embeddings = v(img)
+
+# there is one extra token in each sequence due to the CLS token
+
+embeddings # ((1, 257, 192), (1, 17, 384)) - (batch x patches x dimension) <- small and large scales respectively
+```
+
 ## Research Ideas
 
 ### Efficient Attention
diff --git a/setup.py b/setup.py
index 74bd038..2af2fde 100644
--- a/setup.py
+++ b/setup.py
@@ -3,9 +3,10 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.35.2',
+  version = '0.35.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
+  long_description_content_type = 'text/markdown',
   author = 'Phil Wang',
   author_email = 'lucidrains@gmail.com',
   url = 'https://github.com/lucidrains/vit-pytorch',
diff --git a/vit_pytorch/extractor.py b/vit_pytorch/extractor.py
index 297c577..aff9704 100644
--- a/vit_pytorch/extractor.py
+++ b/vit_pytorch/extractor.py
@@ -4,6 +4,11 @@ from torch import nn
 def exists(val):
     return val is not None
 
+def apply_tuple_or_single(fn, val):
+    if isinstance(val, tuple):
+        return tuple(map(fn, val))
+    return fn(val)
+
 class Extractor(nn.Module):
     def __init__(
         self,
@@ -28,8 +33,8 @@ class Extractor(nn.Module):
         self.return_embeddings_only = return_embeddings_only
 
     def _hook(self, _, inputs, output):
-        tensor_to_save = inputs if self.layer_save_input else output
-        self.latents = tensor_to_save.clone().detach()
+        layer_output = inputs if self.layer_save_input else output
+        self.latents = apply_tuple_or_single(lambda t: t.clone().detach(), layer_output)
 
     def _register_hook(self):
         assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
@@ -62,7 +67,7 @@ class Extractor(nn.Module):
         pred = self.vit(img)
 
         target_device = self.device if exists(self.device) else img.device
-        latents = self.latents.to(target_device)
+        latents = apply_tuple_or_single(lambda t: t.to(target_device), self.latents)
 
         if return_embeddings_only or self.return_embeddings_only:
             return latents
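
For reference, here is a minimal, self-contained sketch (not part of the diff) of the behavior the new `apply_tuple_or_single` helper gives the `Extractor` hook: the function is applied element-wise when the hooked layer returns a tuple (as CrossViT's `multi_scale_encoder` does), and directly otherwise. The tensor shapes below are illustrative only, taken from the README examples above.

```python
import torch

# same helper as added to vit_pytorch/extractor.py in this diff
def apply_tuple_or_single(fn, val):
    if isinstance(val, tuple):
        return tuple(map(fn, val))
    return fn(val)

# what the hook does to whatever the hooked layer produced
detach = lambda t: t.clone().detach()

single = torch.randn(1, 65, 1024)      # e.g. a plain ViT transformer output
multi = (
    torch.randn(1, 257, 192),          # e.g. CrossViT small-scale tokens
    torch.randn(1, 17, 384),           # e.g. CrossViT large-scale tokens
)

print(apply_tuple_or_single(detach, single).shape)              # torch.Size([1, 65, 1024])
print([t.shape for t in apply_tuple_or_single(detach, multi)])  # [torch.Size([1, 257, 192]), torch.Size([1, 17, 384])]
```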