diff --git a/setup.py b/setup.py index 31e83b0..7fdb521 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ with open('README.md') as f: setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.12.1', + version = '1.12.2', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description = long_description, diff --git a/vit_pytorch/vit_nd_rotary.py b/vit_pytorch/vit_nd_rotary.py index cd814a8..afe3158 100644 --- a/vit_pytorch/vit_nd_rotary.py +++ b/vit_pytorch/vit_nd_rotary.py @@ -5,7 +5,7 @@ from torch import nn, arange, cat, stack, Tensor from torch.nn import Module, ModuleList import torch.nn.functional as F -from einops import rearrange, repeat, reduce +from einops import rearrange, repeat, reduce, pack, unpack from einops.layers.torch import Rearrange # helpers @@ -245,7 +245,11 @@ class ViTND(Module): self.to_latent = nn.Identity() self.mlp_head = nn.Linear(dim, num_classes) - def forward(self, x): + def forward( + self, + x, + return_embed = False + ): x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim) batch, *spatial_dims, _, device = *x.shape, x.device @@ -259,16 +263,24 @@ class ViTND(Module): # flatten spatial dimensions for attention with nd rotary pos = repeat(pos, '... p -> b (...) p', b = batch) - x = rearrange(x, 'b ... d -> b (...) d') + x, packed_shape = pack([x], 'b * d') x = self.dropout(x) - x = self.transformer(x, pos) - - x = reduce(x, 'b n d -> b d', 'mean') - - x = self.to_latent(x) - return self.mlp_head(x) + embed = self.transformer(x, pos) + + # return the embed with reconstituted patch shape + + if return_embed: + embed, = unpack(embed, packed_shape, 'b * d') + return embed + + # pooling to logits + + pooled = reduce(embed, 'b n d -> b d', 'mean') + + pooled = self.to_latent(pooled) + return self.mlp_head(pooled) if __name__ == '__main__': @@ -288,5 +300,7 @@ if __name__ == '__main__': ) data = torch.randn(2, 3, 4, 8, 16, 32, 64) - + logits = model(data) + + embed = model(data, return_embed = True) # (2, 2, 4, 4, 8, 8, 512)