From dc4b3327ce8f2b5050aa61a3c833f5109daafb79 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 24 Dec 2020 11:11:58 -0800
Subject: [PATCH] no grad for teacher in distillation

---
 README.md              | 3 +--
 setup.py               | 2 +-
 vit_pytorch/distill.py | 4 +++-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index c1083be..6bcd28c 100644
--- a/README.md
+++ b/README.md
@@ -83,8 +83,7 @@ v = DistillableViT(
     heads = 8,
     mlp_dim = 2048,
     dropout = 0.1,
-    emb_dropout = 0.1,
-    pool = 'mean'
+    emb_dropout = 0.1
 )
 
 distiller = DistillWrapper(
diff --git a/setup.py b/setup.py
index dae4150..0cc51f3 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.6.0',
+  version = '0.6.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/distill.py b/vit_pytorch/distill.py
index c509107..28b4b5b 100644
--- a/vit_pytorch/distill.py
+++ b/vit_pytorch/distill.py
@@ -72,7 +72,9 @@ class DistillWrapper(nn.Module):
         b, *_, alpha = *img.shape, self.alpha
         T = temperature if exists(temperature) else self.temperature
 
-        teacher_logits = self.teacher(img)
+        with torch.no_grad():
+            teacher_logits = self.teacher(img)
+
         student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
 
         distill_logits = self.distill_mlp(distill_tokens)
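
Note: the distill.py hunk above wraps the teacher's forward pass in torch.no_grad(), so the teacher's activations are never stored for backprop and no gradients flow into it; the teacher acts as a frozen, inference-only model. A minimal, self-contained sketch of that pattern follows. The distill_loss helper, its alpha/temperature weighting, and the student/teacher names are illustrative assumptions for this note, not the actual DistillWrapper internals.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def distill_loss(student: nn.Module,
                     teacher: nn.Module,
                     img: torch.Tensor,
                     labels: torch.Tensor,
                     T: float = 3.0,
                     alpha: float = 0.5) -> torch.Tensor:
        # teacher runs without gradient tracking, as in this patch:
        # no activations are saved and no gradients reach its weights
        with torch.no_grad():
            teacher_logits = teacher(img)

        # student runs normally and receives gradients
        student_logits = student(img)

        # hard-label loss against the ground truth
        ce = F.cross_entropy(student_logits, labels)

        # soft distillation loss: KL divergence between temperature-scaled
        # distributions, scaled by T**2 to keep gradient magnitudes
        # comparable across temperatures (a common convention, an
        # assumption here rather than taken from this repo)
        kl = F.kl_div(
            F.log_softmax(student_logits / T, dim = -1),
            F.softmax(teacher_logits / T, dim = -1),
            reduction = 'batchmean'
        ) * T * T

        return ce * (1 - alpha) + kl * alpha

Beyond correctness (the teacher should not be trained), running it under no_grad also avoids storing its intermediate activations, which cuts memory use during distillation training.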