mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
101 lines
3.3 KiB
Python
101 lines
3.3 KiB
Python
import torch
|
|
import math
|
|
|
|
|
|
class CombinedMarginLoss(torch.nn.Module):
    """Margin-softmax head supporting ArcFace- and CosFace-style margins.

    ``forward`` implements exactly two margin configurations:

    * ``m1 == 1.0 and m3 == 0.0`` -- ArcFace: the target logit ``x`` becomes
      ``cos(arccos(x) + m2)``.
    * ``m3 > 0`` -- CosFace: the target logit ``x`` becomes ``x - m3``
      (``m1`` and ``m2`` are ignored in this branch).

    Any other combination raises ``NotImplementedError`` when ``forward``
    is called.

    Args:
        s: Scale multiplied onto all logits after the margin is applied.
        m1: Multiplicative angular margin; only ``1.0`` is supported.
        m2: Additive angular margin (radians) used by the ArcFace branch.
        m3: Additive cosine margin used by the CosFace branch.
        interclass_filtering_threshold: When ``> 0``, every logit above this
            value is zeroed before the margin is applied, except the target
            logit of rows that carry a label (label != -1).
    """

    def __init__(self,
                 s,
                 m1,
                 m2,
                 m3,
                 interclass_filtering_threshold=0):
        super().__init__()
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3
        self.interclass_filtering_threshold = interclass_filtering_threshold

        # For ArcFace. forward() does not read these precomputed constants;
        # they are kept as public attributes for backward compatibility.
        self.cos_m = math.cos(self.m2)
        self.sin_m = math.sin(self.m2)
        self.theta = math.cos(math.pi - self.m2)
        self.sinmm = math.sin(math.pi - self.m2) * self.m2
        self.easy_margin = False

    def forward(self, logits, labels):
        """Apply the configured margin to the target logits and scale.

        Args:
            logits: 2-D tensor indexed as ``logits[row, class]``. The ArcFace
                branch presumably expects cosine similarities already clamped
                to [-1, 1] (``arccos`` yields NaN outside that range). TODO:
                confirm at the call site.
            labels: Integer class index per row, shaped ``(N,)`` or
                ``(N, 1)``. Rows labelled -1 receive no margin.

        Returns:
            The margin-adjusted logits multiplied by ``s``. The input
            ``logits`` tensor may be modified in place.

        Raises:
            NotImplementedError: If the (m1, m2, m3) combination matches
                neither the ArcFace nor the CosFace configuration. This is a
                subclass of RuntimeError, which the previous bare ``raise``
                produced.
        """
        index_positive = torch.where(labels != -1)[0]
        # Target column for each labelled row. view(-1) flattens both (N,)
        # and (N, 1) label layouts to (P,).
        target_cols = labels[index_positive].view(-1)

        if self.interclass_filtering_threshold > 0:
            with torch.no_grad():
                dirty = logits > self.interclass_filtering_threshold
                dirty = dirty.float()
                mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device)
                # scatter_ needs an index with the same rank as `mask`.
                # view(-1, 1) makes 1-D labels work as well as (N, 1) labels.
                mask.scatter_(1, target_cols.view(-1, 1), 0)
                # Never filter out the target logit of a labelled row.
                dirty[index_positive] *= mask
                tensor_mul = 1 - dirty
            logits = tensor_mul * logits

        target_logit = logits[index_positive, target_cols]

        if self.m1 == 1.0 and self.m3 == 0.0:
            # NOTE: the angular transform runs under no_grad, so autograd
            # only records the final multiplication by `s`. The margin is
            # effectively straight-through for gradients; this appears to
            # be intentional.
            with torch.no_grad():
                target_logit.arccos_()
                logits.arccos_()
                final_target_logit = target_logit + self.m2
                logits[index_positive, target_cols] = final_target_logit
                logits.cos_()
            logits = logits * self.s

        elif self.m3 > 0:
            final_target_logit = target_logit - self.m3
            logits[index_positive, target_cols] = final_target_logit
            logits = logits * self.s
        else:
            raise NotImplementedError(
                f"Unsupported margin combination m1={self.m1}, m2={self.m2}, "
                f"m3={self.m3}: expected m1=1.0 and m3=0.0 (ArcFace) or m3>0 (CosFace)"
            )

        return logits
|
|
|
|
class ArcFace(torch.nn.Module):
    """ArcFace additive angular margin head (https://arxiv.org/pdf/1801.07698v1.pdf).

    For every row whose label is not -1, the target logit ``x`` is replaced
    by ``cos(arccos(x) + margin)``. Every other logit goes through
    ``cos(arccos(x))``. All logits are then multiplied by ``s``.

    Args:
        s: Scale applied to the logits after the margin.
        margin: Additive angular margin, in radians.
    """

    def __init__(self, s=64.0, margin=0.5):
        super().__init__()
        self.s = s
        self.margin = margin
        # Precomputed trigonometric constants. forward() does not read them;
        # they remain public attributes for existing callers.
        self.cos_m, self.sin_m = math.cos(margin), math.sin(margin)
        self.theta = math.cos(math.pi - margin)
        self.sinmm = math.sin(math.pi - margin) * margin
        self.easy_margin = False

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Add the angular margin to each labelled row's target logit and scale.

        Args:
            logits: 2-D tensor indexed as ``logits[row, class]``. Presumably
                these are cosine similarities already clamped to [-1, 1]
                (``arccos`` yields NaN outside that range). TODO: confirm at
                the call site.
            labels: Integer class index per row. Both ``(N,)`` and
                ``(N, 1)`` layouts work with the indexing below. Rows
                labelled -1 receive no margin.

        Returns:
            ``logits * s`` after the margin. ``logits`` itself is modified in
            place.

        NOTE: the angular transform runs under ``torch.no_grad()``, so
        autograd only records the final multiplication by ``s``.
        """
        rows = torch.where(labels != -1)[0]
        cols = labels[rows].view(-1)

        with torch.no_grad():
            logits.arccos_()  # cosines -> angles
            logits[rows, cols] = logits[rows, cols] + self.margin
            logits.cos_()  # angles -> cosines
        return logits * self.s
|
|
|
|
|
|
class CosFace(torch.nn.Module):
    """CosFace additive cosine margin head.

    For every row whose label is not -1, the target logit ``x`` becomes
    ``x - m``. All logits are then multiplied by ``s``.

    Args:
        s: Scale applied to the logits after the margin.
        m: Cosine margin subtracted from each target logit.
    """

    def __init__(self, s=64.0, m=0.40):
        super().__init__()
        self.s, self.m = s, m

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Subtract the margin from each labelled row's target logit and scale.

        Args:
            logits: 2-D tensor indexed as ``logits[row, class]``.
            labels: Integer class index per row. Both ``(N,)`` and
                ``(N, 1)`` layouts work with the indexing below. Rows
                labelled -1 are left unchanged before scaling.

        Returns:
            ``logits * s`` after the margin. ``logits`` itself is modified in
            place.
        """
        rows = torch.where(labels != -1)[0]
        cols = labels[rows].view(-1)
        logits[rows, cols] = logits[rows, cols] - self.m
        return logits * self.s
|