mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
Update losses.py
Incorrect `scale` parameter name for ArcFace class.
This commit is contained in:
@@ -62,7 +62,7 @@ class ArcFace(torch.nn.Module):
|
||||
"""
|
||||
def __init__(self, s=64.0, margin=0.5):
|
||||
super(ArcFace, self).__init__()
|
||||
self.scale = s
|
||||
self.s = s
|
||||
self.margin = margin
|
||||
self.cos_m = math.cos(margin)
|
||||
self.sin_m = math.sin(margin)
|
||||
@@ -81,7 +81,7 @@ class ArcFace(torch.nn.Module):
|
||||
final_target_logit = target_logit + self.margin
|
||||
logits[index, labels[index].view(-1)] = final_target_logit
|
||||
logits.cos_()
|
||||
logits = logits * self.s
|
||||
logits = logits * self.s
|
||||
return logits
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user