diff --git a/setup.py b/setup.py
index 5ad1dde..4f4685a 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.6.9',
+  version = '1.7.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
diff --git a/vit_pytorch/distill.py b/vit_pytorch/distill.py
index 79bf8c5..b480e23 100644
--- a/vit_pytorch/distill.py
+++ b/vit_pytorch/distill.py
@@ -1,6 +1,8 @@
 import torch
-import torch.nn.functional as F
 from torch import nn
+from torch.nn import Module
+import torch.nn.functional as F
+
 from vit_pytorch.vit import ViT
 from vit_pytorch.t2t import T2TViT
 from vit_pytorch.efficient import ViT as EfficientViT
@@ -12,6 +14,9 @@ from einops import rearrange, repeat
 def exists(val):
     return val is not None
 
+def default(val, d):
+    return val if exists(val) else d
+
 # classes
 
 class DistillMixin:
@@ -20,12 +25,12 @@ class DistillMixin:
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape
 
-        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
+        cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim = 1)
         x += self.pos_embedding[:, :(n + 1)]
 
         if distilling:
-            distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
+            distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
             x = torch.cat((x, distill_tokens), dim = 1)
 
         x = self._attend(x)
@@ -97,7 +102,7 @@ class DistillableEfficientViT(DistillMixin, EfficientViT):
 
 # knowledge distillation wrapper
 
-class DistillWrapper(nn.Module):
+class DistillWrapper(Module):
     def __init__(
         self,
         *,
@@ -105,7 +110,8 @@
         student,
         temperature = 1.,
         alpha = 0.5,
-        hard = False
+        hard = False,
+        mlp_layernorm = False
     ):
         super().__init__()
         assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
@@ -122,14 +128,14 @@ class DistillWrapper(nn.Module):
         self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
 
         self.distill_mlp = nn.Sequential(
-            nn.LayerNorm(dim),
+            nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
             nn.Linear(dim, num_classes)
         )
 
     def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
-        b, *_ = img.shape
-        alpha = alpha if exists(alpha) else self.alpha
-        T = temperature if exists(temperature) else self.temperature
+
+        alpha = default(alpha, self.alpha)
+        T = default(temperature, self.temperature)
 
         with torch.no_grad():
             teacher_logits = self.teacher(img)
diff --git a/vit_pytorch/t2t.py b/vit_pytorch/t2t.py
index 0ccc7d7..c70004c 100644
--- a/vit_pytorch/t2t.py
+++ b/vit_pytorch/t2t.py
@@ -61,10 +61,7 @@ class T2TViT(nn.Module):
         self.pool = pool
         self.to_latent = nn.Identity()
 
-        self.mlp_head = nn.Sequential(
-            nn.LayerNorm(dim),
-            nn.Linear(dim, num_classes)
-        )
+        self.mlp_head = nn.Linear(dim, num_classes)
 
     def forward(self, img):
         x = self.to_patch_embedding(img)
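For reference, a minimal usage sketch of the new mlp_layernorm flag added to DistillWrapper in this patch. It follows the repository README's distillation example; the teacher model and hyperparameter values are illustrative, not part of this change. With the default mlp_layernorm = False, the distillation head now uses nn.Identity() in place of the LayerNorm; passing mlp_layernorm = True restores the pre-1.7.0 behavior.

import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

# any classifier can serve as the teacher (illustrative choice)
teacher = resnet50(pretrained = True)

student = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = student,
    teacher = teacher,
    temperature = 3,          # distillation temperature
    alpha = 0.5,              # trade-off between cross entropy and distillation loss
    hard = False,             # soft distillation
    mlp_layernorm = False     # new flag: no LayerNorm before the distillation logits
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()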