able to return embed from vit-nd-rotary

lucidrains
2025-09-23 07:21:34 -07:00
parent 845c844b3b
commit f6bc14c81d
2 changed files with 25 additions and 11 deletions

setup.py

@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.12.1',
+  version = '1.12.2',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description = long_description,

vit_pytorch/vit_nd_rotary.py

@@ -5,7 +5,7 @@ from torch import nn, arange, cat, stack, Tensor
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 
-from einops import rearrange, repeat, reduce
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
 
 # helpers
@@ -245,7 +245,11 @@ class ViTND(Module):
         self.to_latent = nn.Identity()
         self.mlp_head = nn.Linear(dim, num_classes)
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        return_embed = False
+    ):
 
         x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim)
         batch, *spatial_dims, _, device = *x.shape, x.device
@@ -259,16 +263,24 @@
         # flatten spatial dimensions for attention with nd rotary
 
         pos = repeat(pos, '... p -> b (...) p', b = batch)
-        x = rearrange(x, 'b ... d -> b (...) d')
+        x, packed_shape = pack([x], 'b * d')
 
         x = self.dropout(x)
 
-        x = self.transformer(x, pos)
-
-        x = reduce(x, 'b n d -> b d', 'mean')
-
-        x = self.to_latent(x)
-
-        return self.mlp_head(x)
+        embed = self.transformer(x, pos)
+
+        # return the embed with reconstituted patch shape
+
+        if return_embed:
+            embed, = unpack(embed, packed_shape, 'b * d')
+            return embed
+
+        # pooling to logits
+
+        pooled = reduce(embed, 'b n d -> b d', 'mean')
+
+        pooled = self.to_latent(pooled)
+
+        return self.mlp_head(pooled)
 
 if __name__ == '__main__':
@@ -288,5 +300,7 @@ if __name__ == '__main__':
     )
 
     data = torch.randn(2, 3, 4, 8, 16, 32, 64)
 
     logits = model(data)
+
+    embed = model(data, return_embed = True) # (2, 2, 4, 4, 8, 8, 512)
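
For readers unfamiliar with the einops calls used above: the new forward path flattens the n-d patch grid into a single sequence axis with pack before attention, and unpack restores that grid when return_embed = True, which is why the returned embed keeps the (batch, *patch_grid, dim) layout noted in the shape comment. A minimal sketch of the pack/unpack round trip, with made-up shapes that are not taken from this commit:

import torch
from einops import pack, unpack

x = torch.randn(2, 4, 4, 8, 512)                   # (batch, *spatial_dims, dim) - illustrative shapes only
packed, packed_shape = pack([x], 'b * d')          # (2, 128, 512) - spatial axes collapsed into one sequence axis
restored, = unpack(packed, packed_shape, 'b * d')  # unpack returns a list; take the single tensor back out
assert restored.shape == x.shape                   # (b, *spatial_dims, d) reconstituted, as in the return_embed branch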