Compare commits

...

1 Commits

Author SHA1 Message Date
Phil Wang
da6699bb41 offer hard distillation as well 2021-04-01 16:41:46 -07:00
3 changed files with 15 additions and 8 deletions

View File

@@ -93,7 +93,8 @@ distiller = DistillWrapper(
student = v,
teacher = teacher,
temperature = 3, # temperature of distillation
alpha = 0.5 # trade between main loss and distillation loss
alpha = 0.5, # trade between main loss and distillation loss
hard = False # whether to use soft or hard distillation
)
img = torch.randn(2, 3, 256, 256)

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.14.2',
version = '0.14.3',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -104,7 +104,8 @@ class DistillWrapper(nn.Module):
teacher,
student,
temperature = 1.,
alpha = 0.5
alpha = 0.5,
hard = False
):
super().__init__()
assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
@@ -116,6 +117,7 @@ class DistillWrapper(nn.Module):
num_classes = student.num_classes
self.temperature = temperature
self.alpha = alpha
self.hard = hard
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
@@ -137,11 +139,15 @@ class DistillWrapper(nn.Module):
loss = F.cross_entropy(student_logits, labels)
distill_loss = F.kl_div(
F.log_softmax(distill_logits / T, dim = -1),
F.softmax(teacher_logits / T, dim = -1).detach(),
reduction = 'batchmean')
if not self.hard:
distill_loss = F.kl_div(
F.log_softmax(distill_logits / T, dim = -1),
F.softmax(teacher_logits / T, dim = -1).detach(),
reduction = 'batchmean')
else:
teacher_labels = teacher_logits.argmax(dim = -1)
distill_loss = F.cross_entropy(student_logits, teacher_labels)
distill_loss *= T ** 2
return loss * alpha + distill_loss * (1 - alpha)