# Mirror of https://github.com/deepinsight/insightface.git
# Synced 2025-12-30 08:02:27 +00:00
# 72 lines, 2.3 KiB, Python
import argparse
|
|
import cv2
|
|
import sys
|
|
import numpy as np
|
|
import os
|
|
import mxnet as mx
|
|
import datetime
|
|
import img_helper
|
|
from config import config
|
|
from data import FaceSegIter
|
|
from metric import LossValueMetric, NMEMetric
|
|
|
|
# Command-line interface for evaluating a heatmap landmark model's NME on a
# .rec dataset.
parser = argparse.ArgumentParser(description='test nme on rec data')
# general
parser.add_argument('--rec',
                    default='./data_2d/ibug.rec',
                    help='rec data path')
# The checkpoint prefix has no usable default: an empty prefix can never be
# loaded by mx.model.load_checkpoint, so require it and fail fast with a
# usage message instead of a late file-not-found error.
parser.add_argument('--prefix', required=True,
                    help='model checkpoint prefix passed to '
                         'mx.model.load_checkpoint')
parser.add_argument('--epoch', type=int, default=1,
                    help='checkpoint epoch to load')
parser.add_argument('--gpu', type=int, default=0,
                    help='GPU index to run on; a negative value selects CPU')
parser.add_argument('--landmark-type', default='2d',
                    help='landmark type, stored in config.landmark_type')
parser.add_argument('--image-size', type=int, default=128,
                    help='side length of the square network input')
args = parser.parse_args()
|
|
|
|
# Copy the parsed CLI options into module-level settings, then push the
# values the data pipeline needs into the shared `config` object
# (presumably read by FaceSegIter -- verify against data.py).
rec_path, ctx_id = args.rec, args.gpu
prefix, epoch = args.prefix, args.epoch
# Square input: the same size is used for both spatial dimensions.
image_size = (args.image_size,) * 2
config.landmark_type = args.landmark_type
config.input_img_size = image_size[0]
|
|
|
|
# A non-negative --gpu index selects that GPU; anything else runs on CPU.
ctx = mx.gpu(ctx_id) if ctx_id >= 0 else mx.cpu()

# Load the saved symbol and parameters for `prefix` at `epoch`, then cut the
# graph at the internal 'heatmap_output' node so the module's output is the
# predicted heatmap and no label input is required.
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
all_layers = sym.get_internals()
sym = all_layers['heatmap_output']
model = mx.mod.Module(symbol=sym, context=ctx,
                      data_names=['data'], label_names=None)

# Inference-only binding for single-sample batches with 3 channels at the
# configured input size.
data_shape = (1, 3, image_size[0], image_size[1])
model.bind(for_training=False, data_shapes=[('data', data_shape)])
model.set_params(arg_params, aux_params)
|
|
|
|
# Walk the record file one sample at a time (batch_size=1, no augmentation),
# run the truncated heatmap network on each sample, and collect the
# per-sample NME computed by NMEMetric.cal_nme. The mean is printed at the end.
val_iter = FaceSegIter(path_imgrec=rec_path, batch_size=1, aug_level=0)
_metric = NMEMetric()
nme = []
for idx, batch in enumerate(val_iter):
    if not idx % 10:
        print('processing', idx)
    # Wrap only the data arrays: the module was bound with no label inputs.
    model.forward(mx.io.DataBatch(batch.data), is_train=False)
    pred_label = model.get_outputs()[-1].asnumpy()
    label = batch.label[0].asnumpy()
    nme.append(_metric.cal_nme(label, pred_label))
print(np.mean(nme))
|