import torch
from torch import nn
from torch.nn import functional as F


class IFLoss(nn.Module):
    """Weighted sum of RGB, eikonal, mask, point-regression and normal losses.

    Each term is computed by its own ``get_*`` method. :meth:`forward`
    combines them as::

        rgb + eikonal_weight * eikonal + mask_weight * mask
            + reg_weight * reg + normal_weight * normal

    and returns the total together with every individual term.
    """

    def __init__(self, eikonal_weight, mask_weight, reg_weight, normal_weight, alpha):
        """Store the loss weights.

        Args:
            eikonal_weight: Multiplier for the eikonal term.
            mask_weight: Multiplier for the mask term.
            reg_weight: Multiplier for the point-regression term.
            normal_weight: Multiplier for the normal-consistency term.
            alpha: Scale applied to the SDF to form BCE logits in the mask
                term; the resulting BCE is also multiplied by ``1 / alpha``.
        """
        super().__init__()
        self.eikonal_weight = eikonal_weight
        self.mask_weight = mask_weight
        self.reg_weight = reg_weight
        self.normal_weight = normal_weight
        self.alpha = alpha
        self.l1_loss = nn.L1Loss(reduction='sum')
        self.l2_loss = nn.MSELoss(reduction='sum')
        # Default dim=1: one similarity per row of the inputs.
        self.cosine = nn.CosineSimilarity()

    @staticmethod
    def _zero(reference):
        """Return a float32 scalar zero on the same device as ``reference``.

        Following the reference tensor's device (instead of forcing CUDA)
        keeps the loss usable on CPU and on non-default GPUs.
        """
        return torch.zeros((), dtype=torch.float32, device=reference.device)

    def get_rgb_loss(self, rgb_values, rgb_gt, network_object_mask, object_mask):
        """L1 colour loss over rays selected by both masks.

        The summed L1 error is divided by the total number of rays
        (``object_mask.shape[0]``), not by the number of selected rays.

        Args:
            rgb_values: Predicted colours, indexed by the boolean ray mask
                (presumably shape (N, 3) — TODO confirm).
            rgb_gt: Ground-truth colours; reshaped to (-1, 3) before masking.
            network_object_mask: Boolean mask of rays the network marks as object.
            object_mask: Boolean ground-truth object mask.

        Returns:
            Scalar tensor; zero when no ray is selected by both masks.
        """
        valid = network_object_mask & object_mask
        if valid.sum() == 0:
            return self._zero(rgb_values)
        selected_pred = rgb_values[valid]
        selected_gt = rgb_gt.reshape(-1, 3)[valid]
        return self.l1_loss(selected_pred, selected_gt) / float(object_mask.shape[0])

    def get_eikonal_loss(self, grad_theta):
        """Mean squared deviation of the row-wise gradient L2 norm from 1.

        Args:
            grad_theta: Gradient vectors, one per row (norm taken over dim 1).

        Returns:
            Scalar tensor; zero when ``grad_theta`` has no rows.
        """
        if grad_theta.shape[0] == 0:
            return self._zero(grad_theta)
        return ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()

    def get_mask_loss(self, sdf_output, network_object_mask, object_mask):
        """BCE mask loss on rays outside the intersection of both masks.

        Logits are ``-alpha * sdf`` and targets are ``object_mask``. The summed
        BCE is scaled by ``1 / alpha`` and divided by the total number of rays.

        Args:
            sdf_output: Per-ray SDF values (e.g. (N,) or (N, 1)).
            network_object_mask: Boolean mask of rays the network marks as object.
            object_mask: Boolean ground-truth object mask.

        Returns:
            Scalar tensor; zero when every ray is in both masks.
        """
        mask = ~(network_object_mask & object_mask)
        if mask.sum() == 0:
            return self._zero(sdf_output)
        # reshape(-1) rather than squeeze(): squeeze() turns a single selected
        # ray into a 0-d tensor, which no longer matches the 1-d target.
        sdf_pred = (-self.alpha * sdf_output[mask]).reshape(-1)
        gt = object_mask[mask].float()
        bce = F.binary_cross_entropy_with_logits(sdf_pred, gt, reduction='sum')
        return (1 / self.alpha) * bce / float(object_mask.shape[0])

    def get_reg_loss(self, point_gt, point_pre):
        """Summed squared error between point sets, divided by the point count.

        Returns:
            Scalar tensor; zero (instead of NaN from 0/0) when no points are given.
        """
        if len(point_pre) == 0:
            return self._zero(point_pre)
        return self.l2_loss(point_gt, point_pre) / len(point_pre)

    def get_normal_loss(self, normals_gt, normals_pred):
        """One minus the mean row-wise cosine similarity between two normal sets.

        Returns:
            Scalar tensor; zero (instead of NaN from an empty mean) when no
            normals are given.
        """
        if normals_pred.shape[0] == 0:
            return self._zero(normals_pred)
        return 1 - torch.mean(self.cosine(normals_gt, normals_pred))

    def forward(self, model_outputs, ground_truth):
        """Compute every loss term and their weighted sum.

        Args:
            model_outputs: Dict with keys ``network_object_mask``,
                ``object_mask``, ``rgb_values``, ``sdf_output``, ``grad_theta``,
                ``points_mesh_ray_gt``, ``points_pre``,
                ``points_mesh_ray_normals`` and ``surface_normals``.
            ground_truth: Dict with key ``rgb``.

        Returns:
            Dict with ``loss`` (the weighted total) and each individual term.
        """
        # Move GT colours to wherever the predictions live (was a hard .cuda()).
        rgb_gt = ground_truth['rgb'].to(model_outputs['rgb_values'].device)
        network_object_mask = model_outputs['network_object_mask']
        object_mask = model_outputs['object_mask']

        rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt,
                                     network_object_mask, object_mask)
        mask_loss = self.get_mask_loss(model_outputs['sdf_output'],
                                       network_object_mask, object_mask)
        eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])
        reg_loss = self.get_reg_loss(model_outputs['points_mesh_ray_gt'],
                                     model_outputs['points_pre'])
        normal_loss = self.get_normal_loss(model_outputs['points_mesh_ray_normals'],
                                           model_outputs['surface_normals'])

        loss = (rgb_loss
                + self.eikonal_weight * eikonal_loss
                + self.mask_weight * mask_loss
                + self.reg_weight * reg_loss
                + self.normal_weight * normal_loss)

        return {
            'loss': loss,
            'rgb_loss': rgb_loss,
            'eikonal_loss': eikonal_loss,
            'mask_loss': mask_loss,
            'reg_loss': reg_loss,
            'normal_loss': normal_loss,
        }