mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
fix distill
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "vit-pytorch"
|
name = "vit-pytorch"
|
||||||
version = "1.16.4"
|
version = "1.16.5"
|
||||||
description = "Vision Transformer (ViT) - Pytorch"
|
description = "Vision Transformer (ViT) - Pytorch"
|
||||||
readme = { file = "README.md", content-type = "text/markdown" }
|
readme = { file = "README.md", content-type = "text/markdown" }
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
|
|||||||
@@ -25,12 +25,12 @@ class DistillMixin:
|
|||||||
x = self.to_patch_embedding(img)
|
x = self.to_patch_embedding(img)
|
||||||
b, n, _ = x.shape
|
b, n, _ = x.shape
|
||||||
|
|
||||||
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
|
cls_tokens = repeat(self.cls_token, 'n d -> b n d', b = b)
|
||||||
x = torch.cat((cls_tokens, x), dim = 1)
|
x = torch.cat((cls_tokens, x), dim = 1)
|
||||||
x += self.pos_embedding[:, :(n + 1)]
|
x += self.pos_embedding[:(n + 1)]
|
||||||
|
|
||||||
if distilling:
|
if distilling:
|
||||||
distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
|
distill_tokens = repeat(distill_token, 'n d -> b n d', b = b)
|
||||||
x = torch.cat((x, distill_tokens), dim = 1)
|
x = torch.cat((x, distill_tokens), dim = 1)
|
||||||
|
|
||||||
x = self._attend(x)
|
x = self._attend(x)
|
||||||
@@ -125,7 +125,7 @@ class DistillWrapper(Module):
|
|||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.hard = hard
|
self.hard = hard
|
||||||
|
|
||||||
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
|
self.distillation_token = nn.Parameter(torch.randn(1, dim))
|
||||||
|
|
||||||
self.distill_mlp = nn.Sequential(
|
self.distill_mlp = nn.Sequential(
|
||||||
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
|
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
|
||||||
|
|||||||
Reference in New Issue
Block a user