Update losses.py

Incorrect `scale` parameter name for ArcFace class.
This commit is contained in:
Parsa
2023-06-10 18:36:19 +02:00
committed by GitHub
parent bc19ea168b
commit 9097cd7ba3

View File

@@ -62,7 +62,7 @@ class ArcFace(torch.nn.Module):
"""
def __init__(self, s=64.0, margin=0.5):
super(ArcFace, self).__init__()
self.scale = s
self.s = s
self.margin = margin
self.cos_m = math.cos(margin)
self.sin_m = math.sin(margin)
@@ -81,7 +81,7 @@ class ArcFace(torch.nn.Module):
final_target_logit = target_logit + self.margin
logits[index, labels[index].view(-1)] = final_target_logit
logits.cos_()
logits = logits * self.s
logits = logits * self.s
return logits