Files
insightface/recognition/arcface_mxnet/metric.py
2021-06-19 23:37:10 +08:00

51 lines
1.7 KiB
Python

import numpy as np
import mxnet as mx
class AccMetric(mx.metric.EvalMetric):
    """Classification accuracy metric registered under the name ``'acc'``.

    Each ``update`` call compares ``labels[0]`` against ``preds[1]`` and
    accumulates the number of matches and the number of compared entries.
    """

    def __init__(self):
        # ``axis`` has to exist before the base constructor runs because it
        # is forwarded to it as a keyword argument.
        self.axis = 1
        super(AccMetric, self).__init__(
            'acc',
            axis=self.axis,
            output_names=None,
            label_names=None,
        )
        self.losses = []  # not used by this class's own methods
        self.count = 0    # number of update() calls so far

    def update(self, labels, preds):
        """Accumulate match counts for one batch.

        Parameters
        ----------
        labels : sequence
            Only ``labels[0]`` is read. It must provide ``.shape`` and
            ``.asnumpy()``.
        preds : sequence
            Only ``preds[1]`` is read (presumably the per-class scores;
            confirm against the network's output order).

        Raises
        ------
        AssertionError
            If the flattened labels and predictions differ in shape.
        """
        self.count += 1
        truth, guess = labels[0], preds[1]

        # Differing shapes mean the prediction has not been reduced to class
        # indices yet, so take the arg-max along ``self.axis``.
        if guess.shape != truth.shape:
            guess = mx.ndarray.argmax(guess, axis=self.axis)
        guess = guess.asnumpy().astype('int32').flatten()

        truth = truth.asnumpy()
        if truth.ndim == 2:
            # Only the first column of a 2-D label array is used.
            truth = truth[:, 0]
        truth = truth.astype('int32').flatten()

        assert truth.shape == guess.shape
        hits = (guess.flat == truth.flat).sum()
        self.sum_metric += hits
        self.num_inst += guess.size
class LossValueMetric(mx.metric.EvalMetric):
    """Running-average metric registered under the name ``'lossvalue'``.

    Each ``update`` call adds the first element of the last entry in
    ``preds`` to the running sum and counts it as one instance.
    """

    def __init__(self):
        # ``axis`` has to exist before the base constructor runs because it
        # is forwarded to it as a keyword argument.
        self.axis = 1
        super(LossValueMetric, self).__init__(
            'lossvalue',
            axis=self.axis,
            output_names=None,
            label_names=None,
        )
        self.losses = []  # not used by this class's own methods

    def update(self, labels, preds):
        """Accumulate one value taken from the last prediction output.

        Parameters
        ----------
        labels : sequence
            Not read by this metric.
        preds : sequence
            Only ``preds[-1]`` is read. It must provide ``.asnumpy()``
            (presumably the loss output of the network; confirm against
            the symbol's output order).
        """
        self.sum_metric += preds[-1].asnumpy()[0]
        self.num_inst += 1.0