# Mirror of https://github.com/deepinsight/insightface.git
# (synced 2026-05-16 13:46:15 +00:00; 70 lines, 2.9 KiB, Python)
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
class IFLoss(nn.Module):
    """Weighted sum of RGB, eikonal, mask, point-regression and normal losses.

    ``loss = rgb_loss + eikonal_weight * eikonal_loss + mask_weight * mask_loss
    + reg_weight * reg_loss + normal_weight * normal_loss``.
    Every term falls back to a float32 scalar zero (on the inputs' device) when
    the set of elements it would be computed over is empty.
    """

    def __init__(self, eikonal_weight, mask_weight, reg_weight, normal_weight, alpha):
        """
        Args:
            eikonal_weight: Weight of the eikonal term.
            mask_weight: Weight of the mask term.
            reg_weight: Weight of the point-regression term.
            normal_weight: Weight of the normal-consistency term.
            alpha: Factor applied to SDF values to form logits (``-alpha * sdf``)
                in the mask loss, which is also scaled by ``1 / alpha``;
                must therefore be non-zero.
        """
        super().__init__()
        self.eikonal_weight = eikonal_weight
        self.mask_weight = mask_weight
        self.reg_weight = reg_weight
        self.normal_weight = normal_weight
        self.alpha = alpha
        self.l1_loss = nn.L1Loss(reduction='sum')
        self.l2_loss = nn.MSELoss(reduction='sum')
        # Default dim=1: similarity is computed row-wise.
        self.cosine = nn.CosineSimilarity()

    @staticmethod
    def _zero(reference):
        """Return a float32 scalar zero on the same device as ``reference``."""
        # Previously hard-coded ``.cuda()``, which crashed on CPU-only runs.
        return torch.zeros((), dtype=torch.float32, device=reference.device)

    def get_rgb_loss(self, rgb_values, rgb_gt, network_object_mask, object_mask):
        """L1 colour loss on rows selected by both boolean masks.

        Args:
            rgb_values: Predicted colours; rows are selected by the masks, so
                the first dim must equal ``object_mask.shape[0]``.
            rgb_gt: Ground-truth colours, reshaped to ``(-1, 3)`` before masking.
            network_object_mask: Boolean mask produced by the model.
            object_mask: Boolean ground-truth object mask.

        Returns:
            Scalar tensor: summed absolute error divided by the TOTAL number of
            rows (``object_mask.shape[0]``), or zero if no row is selected.
        """
        hit = network_object_mask & object_mask
        if not hit.any():
            return self._zero(rgb_values)
        rgb_pred = rgb_values[hit]
        rgb_target = rgb_gt.reshape(-1, 3)[hit]
        return self.l1_loss(rgb_pred, rgb_target) / float(object_mask.shape[0])

    def get_eikonal_loss(self, grad_theta):
        """Mean of ``(||g||_2 - 1) ** 2`` over rows ``g`` (norm taken along dim 1).

        Returns zero when ``grad_theta`` has no rows.
        """
        if grad_theta.shape[0] == 0:
            return self._zero(grad_theta)
        return ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()

    def get_mask_loss(self, sdf_output, network_object_mask, object_mask):
        """Logistic loss on rows NOT selected by both masks.

        Logits are ``-alpha * sdf_output`` and targets are ``object_mask`` as
        float. The summed BCE-with-logits is scaled by ``1 / alpha`` and
        divided by the total number of rows (``object_mask.shape[0]``).
        Returns zero when every row is selected by both masks.
        """
        mask = ~(network_object_mask & object_mask)
        if not mask.any():
            return self._zero(sdf_output)
        # reshape(-1) instead of squeeze(): squeeze() turned a single selected
        # element into a 0-d tensor that no longer matched the (1,)-shaped
        # target, making binary_cross_entropy_with_logits raise.
        sdf_pred = (-self.alpha * sdf_output[mask]).reshape(-1)
        gt = object_mask[mask].float()
        bce = F.binary_cross_entropy_with_logits(sdf_pred, gt, reduction='sum')
        return (1 / self.alpha) * bce / float(object_mask.shape[0])

    def get_reg_loss(self, point_gt, point_pre):
        """Summed squared error between point sets, divided by ``len(point_pre)``.

        Returns zero for an empty prediction set (previously 0 / 0 -> NaN).
        """
        if len(point_pre) == 0:
            return self._zero(point_pre)
        return self.l2_loss(point_gt, point_pre) / len(point_pre)

    def get_normal_loss(self, normals_a, normals_b):
        """One minus the mean row-wise cosine similarity of two normal sets.

        Returns zero when there are no rows (the mean would otherwise be NaN).
        """
        if normals_b.shape[0] == 0:
            return self._zero(normals_b)
        return 1 - torch.mean(self.cosine(normals_a, normals_b))

    def forward(self, model_outputs, ground_truth):
        """Compute all loss terms and their weighted total.

        Args:
            model_outputs: Dict with keys 'rgb_values', 'network_object_mask',
                'object_mask', 'sdf_output', 'grad_theta', 'points_mesh_ray_gt',
                'points_pre', 'points_mesh_ray_normals', 'surface_normals'.
            ground_truth: Dict with key 'rgb'.

        Returns:
            Dict with 'loss' (weighted total) and each individual term.
        """
        rgb_values = model_outputs['rgb_values']
        # Follow the predictions' device instead of forcing .cuda().
        rgb_gt = ground_truth['rgb'].to(rgb_values.device)
        network_object_mask = model_outputs['network_object_mask']
        object_mask = model_outputs['object_mask']

        rgb_loss = self.get_rgb_loss(rgb_values, rgb_gt, network_object_mask, object_mask)
        mask_loss = self.get_mask_loss(model_outputs['sdf_output'], network_object_mask, object_mask)
        eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])
        reg_loss = self.get_reg_loss(model_outputs['points_mesh_ray_gt'], model_outputs['points_pre'])
        normal_loss = self.get_normal_loss(model_outputs['points_mesh_ray_normals'],
                                           model_outputs['surface_normals'])

        loss = (rgb_loss
                + self.eikonal_weight * eikonal_loss
                + self.mask_weight * mask_loss
                + self.reg_weight * reg_loss
                + self.normal_weight * normal_loss)

        return {
            'loss': loss,
            'rgb_loss': rgb_loss,
            'eikonal_loss': eikonal_loss,
            'mask_loss': mask_loss,
            'reg_loss': reg_loss,
            'normal_loss': normal_loss,
        }
|