fix distill

commit 077d8c188f
parent 5888f05300
Author: lucidrains
Date: 2025-12-10 15:52:10 -08:00

2 changed files with 5 additions and 5 deletions

pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "vit-pytorch"
-version = "1.16.4"
+version = "1.16.5"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }

vit_pytorch/distill.py

@@ -25,12 +25,12 @@ class DistillMixin:
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape

-        cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
+        cls_tokens = repeat(self.cls_token, 'n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim = 1)
-        x += self.pos_embedding[:, :(n + 1)]
+        x += self.pos_embedding[:(n + 1)]

         if distilling:
-            distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
+            distill_tokens = repeat(distill_token, 'n d -> b n d', b = b)
             x = torch.cat((x, distill_tokens), dim = 1)

         x = self._attend(x)
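
Note (not part of the diff): the repeat patterns change because the learned tokens lose their explicit leading batch axis, going from shape (1, 1, dim) to (1, dim); pos_embedding likewise drops its leading axis, so the slice becomes [:(n + 1)] and relies on broadcasting over the batch. A minimal sketch, assuming only torch and einops, of why both patterns yield the same (b, 1, dim) tensor:

import torch
from einops import repeat

b, dim = 4, 64

# old layout: token stored with an explicit leading batch axis of size 1
cls_token_old = torch.randn(1, 1, dim)
tokens_old = repeat(cls_token_old, '1 n d -> b n d', b = b)

# new layout: the leading axis is dropped; repeat broadcasts over the batch
cls_token_new = cls_token_old[0]                      # shape (1, dim)
tokens_new = repeat(cls_token_new, 'n d -> b n d', b = b)

assert tokens_old.shape == tokens_new.shape == (b, 1, dim)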
@@ -125,7 +125,7 @@ class DistillWrapper(Module):
         self.alpha = alpha
         self.hard = hard

-        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.distillation_token = nn.Parameter(torch.randn(1, dim))

         self.distill_mlp = nn.Sequential(
             nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
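
Note (not part of the diff): initializing the distillation token as (1, dim) makes it match the 'n d -> b n d' pattern in the hunk above. A sketch checking the shapes end to end; b, n and dim are illustrative values, and a pos_embedding of shape (n + 1, dim) is an assumption about the updated ViT, not library code:

import torch
from torch import nn
from einops import repeat

b, n, dim = 4, 16, 64

distillation_token = nn.Parameter(torch.randn(1, dim))   # new (1, dim) layout
pos_embedding = nn.Parameter(torch.randn(n + 1, dim))    # assumed (seq, dim) layout

x = torch.randn(b, n + 1, dim)          # patch tokens plus cls token
x = x + pos_embedding[:(n + 1)]         # (n + 1, dim) broadcasts over the batch

distill_tokens = repeat(distillation_token, 'n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)

assert x.shape == (b, n + 2, dim)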