Files
insightface/reconstruction/PBIDR/code/model/loss.py
2022-03-19 14:24:51 +08:00

70 lines
2.9 KiB
Python

import torch
from torch import nn
from torch.nn import functional as F
class IFLoss(nn.Module):
    """Weighted sum of RGB, eikonal, mask, point-regression and normal losses.

    Total loss = rgb + eikonal_weight * eikonal + mask_weight * mask
                 + reg_weight * reg + normal_weight * normal.

    Args:
        eikonal_weight: weight of the eikonal (unit gradient norm) term.
        mask_weight: weight of the mask term.
        reg_weight: weight of the point-regression term.
        normal_weight: weight of the normal-consistency term.
        alpha: scale applied to ``sdf_output`` before the BCE in the mask term;
            the mask term is additionally multiplied by ``1 / alpha``.
    """

    def __init__(self, eikonal_weight, mask_weight, reg_weight, normal_weight, alpha):
        super().__init__()
        self.eikonal_weight = eikonal_weight
        self.mask_weight = mask_weight
        self.reg_weight = reg_weight
        self.normal_weight = normal_weight
        self.alpha = alpha
        self.l1_loss = nn.L1Loss(reduction='sum')
        self.l2_loss = nn.MSELoss(reduction='sum')
        self.cosine = nn.CosineSimilarity()  # default dim=1

    @staticmethod
    def _zero(reference):
        """Return a scalar float32 zero on the same device as ``reference``."""
        # Follow the inputs' device rather than hard-coding CUDA, so the loss
        # also works on CPU and on non-default GPUs.
        return torch.zeros((), dtype=torch.float32, device=reference.device)

    def get_rgb_loss(self, rgb_values, rgb_gt, network_object_mask, object_mask):
        """Summed L1 colour error over pixels inside both masks.

        Args:
            rgb_values: predicted colours, indexed directly by the pixel mask.
            rgb_gt: ground-truth colours; reshaped to ``(-1, 3)`` before masking.
            network_object_mask: boolean mask of pixels the network marks as object.
            object_mask: boolean ground-truth object mask.

        Returns:
            Scalar tensor: the summed L1 error divided by ``object_mask.shape[0]``,
            or zero when the two masks do not overlap.
        """
        hit = network_object_mask & object_mask
        if hit.sum() == 0:
            return self._zero(rgb_values)
        rgb_loss = self.l1_loss(rgb_values[hit], rgb_gt.reshape(-1, 3)[hit])
        return rgb_loss / float(object_mask.shape[0])

    def get_eikonal_loss(self, grad_theta):
        """Mean squared deviation of the row-wise L2 norm of ``grad_theta`` from 1.

        Returns zero when ``grad_theta`` has no rows.
        """
        if grad_theta.shape[0] == 0:
            return self._zero(grad_theta)
        return ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()

    def get_mask_loss(self, sdf_output, network_object_mask, object_mask):
        """BCE-with-logits mask loss on pixels NOT inside both masks.

        The logits are ``-alpha * sdf_output`` and the targets are ``object_mask``
        at those pixels. The summed BCE is scaled by ``1 / alpha`` and divided by
        ``object_mask.shape[0]``. Returns zero when every pixel is inside both masks.
        """
        mask = ~(network_object_mask & object_mask)
        if mask.sum() == 0:
            return self._zero(sdf_output)
        # reshape(-1) rather than squeeze(): with a single selected pixel,
        # squeeze() produced a 0-d tensor that no longer matched gt's (1,) shape.
        sdf_pred = (-self.alpha * sdf_output[mask]).reshape(-1)
        gt = object_mask[mask].float()
        bce = F.binary_cross_entropy_with_logits(sdf_pred, gt, reduction='sum')
        return (1 / self.alpha) * bce / float(object_mask.shape[0])

    def get_reg_loss(self, point_gt, point_pre):
        """Summed squared error between point sets, divided by ``len(point_pre)``.

        Returns zero for an empty ``point_pre`` instead of dividing by zero (NaN).
        """
        if len(point_pre) == 0:
            return self._zero(point_pre)
        return self.l2_loss(point_gt, point_pre) / len(point_pre)

    def get_normal_loss(self, mesh_normals, surface_normals):
        """``1 - mean cosine similarity`` between paired normals (along dim 1).

        Returns zero when there are no normals, instead of the NaN produced by
        the mean of an empty tensor.
        """
        if surface_normals.shape[0] == 0:
            return self._zero(surface_normals)
        return 1 - torch.mean(self.cosine(mesh_normals, surface_normals))

    def forward(self, model_outputs, ground_truth):
        """Compute every loss term and their weighted total.

        Args:
            model_outputs: dict providing ``rgb_values``, ``network_object_mask``,
                ``object_mask``, ``sdf_output``, ``grad_theta``,
                ``points_mesh_ray_gt``, ``points_pre``, ``points_mesh_ray_normals``
                and ``surface_normals``.
            ground_truth: dict providing ``rgb``.

        Returns:
            dict with keys ``loss`` (weighted total), ``rgb_loss``, ``eikonal_loss``,
            ``mask_loss``, ``reg_loss`` and ``normal_loss``.
        """
        rgb_values = model_outputs['rgb_values']
        # Move the target next to the prediction instead of assuming CUDA.
        rgb_gt = ground_truth['rgb'].to(rgb_values.device)
        network_object_mask = model_outputs['network_object_mask']
        object_mask = model_outputs['object_mask']

        rgb_loss = self.get_rgb_loss(rgb_values, rgb_gt, network_object_mask, object_mask)
        mask_loss = self.get_mask_loss(model_outputs['sdf_output'], network_object_mask, object_mask)
        eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])
        reg_loss = self.get_reg_loss(model_outputs['points_mesh_ray_gt'], model_outputs['points_pre'])
        normal_loss = self.get_normal_loss(model_outputs['points_mesh_ray_normals'],
                                           model_outputs['surface_normals'])

        loss = (rgb_loss
                + self.eikonal_weight * eikonal_loss
                + self.mask_weight * mask_loss
                + self.reg_weight * reg_loss
                + self.normal_weight * normal_loss)

        return {
            'loss': loss,
            'rgb_loss': rgb_loss,
            'eikonal_loss': eikonal_loss,
            'mask_loss': mask_loss,
            'reg_loss': reg_loss,
            'normal_loss': normal_loss,
        }