mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da6699bb41 |
@@ -93,7 +93,8 @@ distiller = DistillWrapper(
|
||||
student = v,
|
||||
teacher = teacher,
|
||||
temperature = 3, # temperature of distillation
|
||||
alpha = 0.5 # trade between main loss and distillation loss
|
||||
alpha = 0.5, # trade between main loss and distillation loss
|
||||
hard = False # whether to use soft or hard distillation
|
||||
)
|
||||
|
||||
img = torch.randn(2, 3, 256, 256)
|
||||
|
||||
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.14.2',
|
||||
version = '0.14.3',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -104,7 +104,8 @@ class DistillWrapper(nn.Module):
|
||||
teacher,
|
||||
student,
|
||||
temperature = 1.,
|
||||
alpha = 0.5
|
||||
alpha = 0.5,
|
||||
hard = False
|
||||
):
|
||||
super().__init__()
|
||||
assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
|
||||
@@ -116,6 +117,7 @@ class DistillWrapper(nn.Module):
|
||||
num_classes = student.num_classes
|
||||
self.temperature = temperature
|
||||
self.alpha = alpha
|
||||
self.hard = hard
|
||||
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
|
||||
@@ -137,11 +139,15 @@ class DistillWrapper(nn.Module):
|
||||
|
||||
loss = F.cross_entropy(student_logits, labels)
|
||||
|
||||
distill_loss = F.kl_div(
|
||||
F.log_softmax(distill_logits / T, dim = -1),
|
||||
F.softmax(teacher_logits / T, dim = -1).detach(),
|
||||
reduction = 'batchmean')
|
||||
if not self.hard:
|
||||
distill_loss = F.kl_div(
|
||||
F.log_softmax(distill_logits / T, dim = -1),
|
||||
F.softmax(teacher_logits / T, dim = -1).detach(),
|
||||
reduction = 'batchmean')
|
||||
|
||||
else:
|
||||
teacher_labels = teacher_logits.argmax(dim = -1)
|
||||
distill_loss = F.cross_entropy(student_logits, teacher_labels)
|
||||
|
||||
distill_loss *= T ** 2
|
||||
|
||||
return loss * alpha + distill_loss * (1 - alpha)
|
||||
|
||||
Reference in New Issue
Block a user