diff --git a/pyproject.toml b/pyproject.toml
index c64261c..b0959d0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.16.4"
+version = "1.16.5"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
diff --git a/vit_pytorch/distill.py b/vit_pytorch/distill.py
index b480e23..994c96e 100644
--- a/vit_pytorch/distill.py
+++ b/vit_pytorch/distill.py
@@ -25,12 +25,12 @@ class DistillMixin:
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape
 
-        cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
+        cls_tokens = repeat(self.cls_token, 'n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim = 1)
-        x += self.pos_embedding[:, :(n + 1)]
+        x += self.pos_embedding[:(n + 1)]
 
         if distilling:
-            distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
+            distill_tokens = repeat(distill_token, 'n d -> b n d', b = b)
             x = torch.cat((x, distill_tokens), dim = 1)
 
         x = self._attend(x)
@@ -125,7 +125,7 @@ class DistillWrapper(Module):
         self.alpha = alpha
         self.hard = hard
 
-        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.distillation_token = nn.Parameter(torch.randn(1, dim))
 
         self.distill_mlp = nn.Sequential(
             nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
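
Note (not part of the diff): a minimal sketch of why the updated einops patterns and positional-embedding indexing still broadcast correctly, assuming the cls_token and pos_embedding parameters of the wrapped ViT were likewise reduced from (1, 1, dim) / (1, n, dim) to (1, dim) / (n, dim). The shapes and variable names below are illustrative only.

    import torch
    from einops import repeat

    b, n, dim = 4, 16, 64                     # batch, number of patches, embedding dim

    cls_token     = torch.randn(1, dim)       # was (1, 1, dim) before this change
    pos_embedding = torch.randn(n + 1, dim)   # was (1, n + 1, dim)
    distill_token = torch.randn(1, dim)       # matches the new (1, dim) distillation_token

    x = torch.randn(b, n, dim)                # patch embeddings

    cls_tokens = repeat(cls_token, 'n d -> b n d', b = b)   # (b, 1, dim)
    x = torch.cat((cls_tokens, x), dim = 1)                  # (b, n + 1, dim)
    x = x + pos_embedding[:(n + 1)]                          # (n + 1, dim) broadcasts over batch

    distill_tokens = repeat(distill_token, 'n d -> b n d', b = b)
    x = torch.cat((x, distill_tokens), dim = 1)              # (b, n + 2, dim)

    print(x.shape)  # torch.Size([4, 18, 64])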