# Source: insightface/recognition/oneflow_face/insightface_val.py
# Snapshot: 2021-01-20 17:25:35 +08:00 (188 lines, 6.6 KiB, Python)

import math, os
import argparse
import numpy as np
import oneflow as flow
from config import config, default, generate_val_config
import ofrecord_util
import validation_util
from symbols import fresnet100, fmobilefacenet
def get_val_args():
    """Build and parse the command-line flags used for validation.

    Parsing is done in two passes: ``-network`` is read first so that
    ``generate_val_config`` can be called with it before
    ``config.val_targets`` is iterated to register one
    ``--<target>_dataset_dir`` flag per validation target.

    Returns:
        argparse.Namespace: the parsed validation arguments.
    """
    # Pass 1: a bootstrap parser that only knows -network. It is created
    # with add_help=False so that "-h/--help" is left among the unknown
    # arguments instead of printing an incomplete usage message and exiting
    # before the remaining flags have been registered.
    bootstrap_parser = argparse.ArgumentParser(add_help=False)
    bootstrap_parser.add_argument("-network", default=default.network)
    bootstrap_args, _ = bootstrap_parser.parse_known_args()
    generate_val_config(bootstrap_args.network)

    # Pass 2: the full parser, including -network again so that it is
    # accepted and reported in the final namespace and in --help.
    val_parser = argparse.ArgumentParser(description="flags for validation")
    val_parser.add_argument("-network", default=default.network)

    for ds in config.val_targets:
        val_parser.add_argument(
            "--%s_dataset_dir" % ds,
            type=str,
            default=os.path.join(default.val_dataset_dir, ds),
            help="validation dataset dir",
        )
    # NOTE(review): this flag is parsed as a string and its help text looks
    # copy-pasted from the dataset-dir flag -- confirm the intended type and
    # wording against the code that consumes val_data_part_num.
    val_parser.add_argument(
        "--val_data_part_num",
        type=str,
        default=default.val_data_part_num,
        help="validation dataset dir prefix",
    )
    for flag, image_count in (
        ("--lfw_total_images_num", 12000),
        ("--cfp_fp_total_images_num", 14000),
        ("--agedb_30_total_images_num", 12000),
    ):
        val_parser.add_argument(flag, type=int, default=image_count)

    # distribution config
    val_parser.add_argument(
        "--device_num_per_node",
        type=int,
        default=default.device_num_per_node,
    )
    val_parser.add_argument(
        "--num_nodes",
        type=int,
        default=default.num_nodes,
        help="node/machine number for training",
    )
    val_parser.add_argument(
        "--val_batch_size_per_device",
        default=default.val_batch_size_per_device,
        type=int,
        help="validation batch size per device",
    )
    val_parser.add_argument(
        "--nrof_folds", default=default.nrof_folds, type=int, help="nrof folds"
    )

    # model and log
    val_parser.add_argument(
        "--log_dir", type=str, default=default.log_dir, help="log info save"
    )
    val_parser.add_argument(
        "--model_load_dir", default=default.model_load_dir, help="path to load model."
    )
    return val_parser.parse_args()
def flip_data(images):
    """Return a float32 copy of ``images`` reversed along axis 2.

    Args:
        images: array-like with at least three dimensions.

    Returns:
        numpy.ndarray: the input mirrored along axis 2, cast to float32.
    """
    mirrored = np.flip(images, axis=2)
    return mirrored.astype(np.float32)
def get_val_config():
    """Create the OneFlow function config used by the validation jobs.

    The config is set to the consistent logical view and to ``flow.float``
    as its default data type.

    Returns:
        The configured object returned by ``flow.function_config()``.
    """
    # Named func_cfg so it does not shadow the module-level ``config`` import.
    func_cfg = flow.function_config()
    consistent = flow.scope.consistent_view()
    func_cfg.default_logical_view(consistent)
    func_cfg.default_data_type(flow.float)
    return func_cfg
class Validator(object):
    """Builds the validation jobs and collects embeddings for a dataset.

    The OneFlow jobs are only created when
    ``default.do_validation_while_train`` is truthy; otherwise none of the
    ``*_fn`` attributes exist and ``do_validation`` cannot be used.
    """

    # dataset name -> (attribute on ``self.args`` holding the image count,
    #                  attribute on ``self`` holding the data-loading job).
    # The "datset" spelling of the lfw attribute is kept as-is because it is
    # part of the instance's existing attribute names.
    _DATASET_SPECS = {
        "lfw": ("lfw_total_images_num", "get_validation_datset_lfw_fn"),
        "cfp_fp": ("cfp_fp_total_images_num", "get_validation_dataset_cfp_fp_fn"),
        "agedb_30": (
            "agedb_30_total_images_num",
            "get_validation_dataset_agedb_30_fn",
        ),
    }

    def __init__(self, args):
        """Store ``args`` and, if enabled, define the validation jobs.

        Args:
            args: parsed arguments; ``val_batch_size_per_device`` is read
                here, and the whole object is passed to the
                ``ofrecord_util`` dataset loaders.
        """
        self.args = args
        # NOTE(review): when this flag is falsy no job is defined, so a later
        # do_validation() call fails with AttributeError -- confirm that the
        # standalone validation entry point always runs with it enabled.
        if default.do_validation_while_train:
            function_config = get_val_config()

            @flow.global_function(type="predict", function_config=function_config)
            def get_validation_datset_lfw_job():
                with flow.scope.placement("cpu", "0:0"):
                    issame, images = ofrecord_util.load_lfw_dataset(self.args)
                    return issame, images

            self.get_validation_datset_lfw_fn = get_validation_datset_lfw_job

            @flow.global_function(type="predict", function_config=function_config)
            def get_validation_dataset_cfp_fp_job():
                with flow.scope.placement("cpu", "0:0"):
                    issame, images = ofrecord_util.load_cfp_fp_dataset(self.args)
                    return issame, images

            self.get_validation_dataset_cfp_fp_fn = get_validation_dataset_cfp_fp_job

            @flow.global_function(type="predict", function_config=function_config)
            def get_validation_dataset_agedb_30_job():
                with flow.scope.placement("cpu", "0:0"):
                    issame, images = ofrecord_util.load_agedb_30_dataset(self.args)
                    return issame, images

            self.get_validation_dataset_agedb_30_fn = (
                get_validation_dataset_agedb_30_job
            )

            @flow.global_function(type="predict", function_config=function_config)
            def get_symbol_val_job(
                images: flow.typing.Numpy.Placeholder(
                    (self.args.val_batch_size_per_device, 112, 112, 3)
                )
            ):
                print("val batch data: ", images.shape)
                # NOTE(review): eval() resolves the network module from the
                # configured name; this is only safe while config.net_name
                # comes from trusted project configuration.
                embedding = eval(config.net_name).get_symbol(images)
                return embedding

            self.get_symbol_val_fn = get_symbol_val_job

    def do_validation(self, dataset="lfw"):
        """Run the embedding job over one validation dataset.

        Args:
            dataset: one of ``"lfw"``, ``"cfp_fp"`` or ``"agedb_30"``.

        Returns:
            tuple: ``(issame_list, embeddings_list)``. ``issame_list`` holds
            one bool for every second label row; ``embeddings_list`` is
            ``[embeddings, embeddings_flipped]``, each truncated to the
            configured total image count for the dataset.

        Raises:
            ValueError: if ``dataset`` is not a supported dataset name.
        """
        print("Validation on [{}]:".format(dataset))
        try:
            total_attr, job_attr = self._DATASET_SPECS[dataset]
        except KeyError:
            # Previously an unknown name fell through every ``if`` and
            # surfaced later as an UnboundLocalError.
            raise ValueError(
                "unsupported validation dataset {!r}; expected one of {}".format(
                    dataset, sorted(self._DATASET_SPECS)
                )
            ) from None
        total_images_num = getattr(self.args, total_attr)
        val_job = getattr(self, job_attr)
        batch_size = self.args.val_batch_size_per_device

        _issame_list = []
        _em_list = []
        _em_flipped_list = []
        # The last batch may run past total_images_num; the surplus rows are
        # cut off by the slicing below.
        val_iter_num = math.ceil(total_images_num / batch_size)
        for _ in range(val_iter_num):
            _issame, images = val_job().get()
            batch = images.numpy()
            images_flipped = flip_data(batch)
            _em = self.get_symbol_val_fn(batch).get()
            _em_flipped = self.get_symbol_val_fn(images_flipped).get()
            _issame_list.append(_issame.numpy())
            _em_list.append(_em.numpy())
            _em_flipped_list.append(_em_flipped.numpy())

        issame = np.array(_issame_list).flatten().reshape(-1, 1)[:total_images_num, :]
        # Only every second label is kept -- presumably one label per image
        # pair; confirm against the dataset layout.
        issame_list = [bool(x) for x in issame[0::2]]
        embedding_length = _em_list[0].shape[-1]
        embeddings = (np.array(_em_list).flatten().reshape(-1, embedding_length))[
            :total_images_num, :
        ]
        embeddings_flipped = (
            np.array(_em_flipped_list).flatten().reshape(-1, embedding_length)
        )[:total_images_num, :]
        embeddings_list = [embeddings, embeddings_flipped]
        return issame_list, embeddings_list

    def load_checkpoint(self):
        """Load model variables from ``self.args.model_load_dir``."""
        flow.load_variables(flow.checkpoint.get(self.args.model_load_dir))
def main():
    """Parse flags, load the checkpoint and validate every configured target.

    For each entry in ``config.val_targets`` the embeddings are computed and
    handed to ``validation_util.cal_validation_metrics``.
    """
    args = get_val_args()
    flow.env.log_dir(args.log_dir)
    flow.config.gpu_device_num(args.device_num_per_node)

    # validation
    validator = Validator(args)
    validator.load_checkpoint()
    for target in config.val_targets:
        labels, embeddings = validator.do_validation(dataset=target)
        validation_util.cal_validation_metrics(
            embeddings,
            labels,
            nrof_folds=args.nrof_folds,
        )
# Run validation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()