mirror of
https://github.com/deepinsight/insightface.git
Create lovasz_softmax.py
parsing/dml_csr/loss/lovasz_softmax.py (new normal file, 280 lines)
@@ -0,0 +1,280 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author   : Qingping Zheng
@Contact  : qingpingzheng2014@gmail.com
@File     : lovasz_softmax.py
@Time     : 10/01/21 00:00 PM
@Desc     :
@License  : Licensed under the Apache License, Version 2.0 (the "License");
@Copyright: Copyright 2015 The Authors. All Rights Reserved.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch
import torch.nn.functional as F

from torch import nn
from torch.autograd import Variable

try:
    from itertools import ifilterfalse
except ImportError:  # py3k
    from itertools import filterfalse as ifilterfalse


def lovasz_grad(gt_sorted):
    """
    Computes the gradient of the Lovasz extension w.r.t. sorted errors.
    See Alg. 1 in the paper.
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:  # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard
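
# Illustrative sketch (toy input assumed): given binary ground truth already
# sorted by descending prediction error, lovasz_grad returns the marginal
# weight each position contributes to the Jaccard loss; by telescoping, the
# weights sum to 1.
#
#   >>> lovasz_grad(torch.tensor([1, 1, 0, 1, 0]))
#   tensor([0.3333, 0.3333, 0.0833, 0.2500, 0.0000])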


def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    IoU for the foreground class.
    binary: 1 foreground, 0 background
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        intersection = ((label == 1) & (pred == 1)).sum()
        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
        if not union:
            iou = EMPTY
        else:
            iou = float(intersection) / float(union)
        ious.append(iou)
    iou = mean(ious)  # mean across images if per_image
    return 100 * iou


def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU for each (non-ignored) class.
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        iou = []
        for i in range(C):
            if i != ignore:  # The ignored label is sometimes among predicted classes (ENet - CityScapes)
                intersection = ((label == i) & (pred == i)).sum()
                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
                if not union:
                    iou.append(EMPTY)
                else:
                    iou.append(float(intersection) / float(union))
        ious.append(iou)
    ious = [mean(iou) for iou in zip(*ious)]  # mean across images if per_image
    return 100 * np.array(ious)
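
# Illustrative sketch (toy tensors assumed): per-class IoU in percent for a
# 3-class prediction; pixels whose label equals `ignore` are excluded from
# every union, so the void pixel below does not penalise class 1.
#
#   >>> pred  = torch.tensor([[0, 1, 2, 1]])
#   >>> label = torch.tensor([[0, 1, 2, 255]])
#   >>> iou(pred, label, C=3, ignore=255, per_image=True)
#   array([100., 100., 100.])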


# --------------------------- BINARY LOSSES ---------------------------


def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss.
      logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      per_image: compute the loss per image instead of per batch
      ignore: void class id
    """
    if per_image:
        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
                    for log, lab in zip(logits, labels))
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss
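
# Minimal usage sketch (shapes assumed): the hinge variant consumes raw
# logits, not probabilities, and is differentiable end to end.
#
#   >>> logits = torch.randn(2, 4, 4, requires_grad=True)
#   >>> masks = (torch.rand(2, 4, 4) > 0.5).long()
#   >>> loss = lovasz_hinge(logits, masks, per_image=True)
#   >>> loss.backward()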


def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss.
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss


def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case).
    Removes labels equal to 'ignore'.
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels


class StableBCELoss(torch.nn.modules.Module):
    def __init__(self):
        super(StableBCELoss, self).__init__()

    def forward(self, input, target):
        neg_abs = - input.abs()
        loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
        return loss.mean()
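
# Note on the formula above: max(x, 0) - x * z + log(1 + exp(-|x|)) is the
# standard numerically stable rewriting of binary cross-entropy with logits,
# -z * log(sigmoid(x)) - (1 - z) * log(1 - sigmoid(x)); it never exponentiates
# a large positive x, so it cannot overflow.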


def binary_xloss(logits, labels, ignore=None):
    """
    Binary cross-entropy loss.
      logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      ignore: void class id
    """
    logits, labels = flatten_binary_scores(logits, labels, ignore)
    loss = StableBCELoss()(logits, Variable(labels.float()))
    return loss


# --------------------------- MULTICLASS LOSSES ---------------------------


def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=255, weighted=None):
    """
    Multi-class Lovasz-Softmax loss.
      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
              If C = 1, interpreted as binary (sigmoid) output of size [B, H, W].
      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      per_image: compute the loss per image instead of per batch
      ignore: void class labels
      weighted: optional per-class weights, indexed by class id
    """
    if per_image:
        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore),
                                        classes=classes, weighted=weighted)
                    for prob, lab in zip(probas, labels))
    else:
        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes, weighted=weighted)
    return loss
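
# Minimal usage sketch (shapes assumed): the function expects probabilities,
# so apply softmax to raw network outputs first.
#
#   >>> scores = torch.randn(2, 5, 8, 8)            # B=2, C=5 raw logits
#   >>> probas = F.softmax(scores, dim=1)
#   >>> labels = torch.randint(0, 5, (2, 8, 8))     # ground truth in [0, C-1]
#   >>> loss = lovasz_softmax(probas, labels, classes='present', ignore=255)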


def lovasz_softmax_flat(probas, labels, classes='present', weighted=None):
    """
    Multi-class Lovasz-Softmax loss.
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      weighted: optional per-class weights, indexed by class id
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        if classes == 'present' and fg.sum() == 0:
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (Variable(fg) - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        if weighted is not None:
            losses.append(weighted[c] * torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
        else:
            losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)
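
# Illustrative sketch (toy shapes and weights assumed): `weighted` scales each
# per-class Lovasz term, e.g. to emphasise rare classes; it must be indexable
# by every class id that is summed.
#
#   >>> probas = F.softmax(torch.randn(100, 3), dim=1)   # P=100 pixels, C=3
#   >>> labels = torch.randint(0, 3, (100,))
#   >>> loss = lovasz_softmax_flat(probas, labels, classes='all', weighted=[1.0, 2.0, 2.0])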


def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch.
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    # older call style: vprobas = probas[valid.nonzero().squeeze()]
    vprobas = probas[torch.nonzero(valid, as_tuple=False).squeeze()]
    vlabels = labels[valid]
    return vprobas, vlabels


def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss.
    """
    # honour the `ignore` argument; keep 255 as the historical default
    return F.cross_entropy(logits, Variable(labels), ignore_index=255 if ignore is None else ignore)


# --------------------------- HELPER FUNCTIONS ---------------------------

def isnan(x):
    return x != x


def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n
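
# Illustrative sketch: unlike np.mean, this helper consumes plain iterables
# and generators lazily, and can skip NaN entries.
#
#   >>> mean(x for x in [1.0, 2.0, 3.0])
#   2.0
#   >>> mean([1.0, float('nan')], ignore_nan=True)
#   1.0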

# --------------------------- Class ---------------------------
class LovaszSoftmax(nn.Module):
    def __init__(self, per_image=False, ignore_index=255, weighted=None):
        super(LovaszSoftmax, self).__init__()
        self.lovasz_softmax = lovasz_softmax
        self.per_image = per_image
        self.ignore_index = ignore_index
        self.weighted = weighted

    def forward(self, pred, label):
        pred = F.softmax(pred, dim=1)
        return self.lovasz_softmax(pred, label, per_image=self.per_image, ignore=self.ignore_index, weighted=self.weighted)
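
# Minimal end-to-end sketch (shapes assumed): the module takes raw logits and
# applies softmax internally before the Lovasz-Softmax loss.
if __name__ == '__main__':
    criterion = LovaszSoftmax(per_image=False, ignore_index=255)
    scores = torch.randn(2, 5, 16, 16, requires_grad=True)  # B=2, C=5
    target = torch.randint(0, 5, (2, 16, 16))
    target[0, 0, 0] = 255                                   # one void pixel
    loss = criterion(scores, target)
    loss.backward()
    print('lovasz-softmax loss:', float(loss))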