diff --git a/recognition/_evaluation_/ijb/ijb_evals.py b/recognition/_evaluation_/ijb/ijb_evals.py
index e88ce92..9652c89 100755
--- a/recognition/_evaluation_/ijb/ijb_evals.py
+++ b/recognition/_evaluation_/ijb/ijb_evals.py
@@ -1,12 +1,12 @@
 #!/usr/bin/env python3
 import os
+import cv2
 import numpy as np
+import pandas as pd
 from tqdm import tqdm
 from skimage import transform
 from sklearn.preprocessing import normalize
 from sklearn.metrics import roc_curve, auc
-import pandas as pd
-import cv2


 class Mxnet_model_interf:
@@ -62,8 +62,27 @@ class Torch_model_interf:
         return output.cpu().detach().numpy()


+class ONNX_model_interf:
+    def __init__(self, model_file, image_size=(112, 112)):
+        import onnxruntime as ort
+        ort.set_default_logger_severity(3)  # silence onnxruntime warnings
+        self.ort_session = ort.InferenceSession(model_file)
+        self.output_names = [self.ort_session.get_outputs()[0].name]
+        self.input_name = self.ort_session.get_inputs()[0].name
+
+    def __call__(self, imgs):
+        imgs = imgs.transpose(0, 3, 1, 2).astype("float32")  # NHWC -> NCHW
+        imgs = (imgs - 127.5) * 0.0078125  # normalize to [-1, 1], 0.0078125 == 1 / 128
+        outputs = self.ort_session.run(self.output_names, {self.input_name: imgs})
+        return outputs[0]
+
+
 def keras_model_interf(model_file):
     import tensorflow as tf
+    from tensorflow_addons.layers import StochasticDepth  # import registers the layer, so load_model can deserialize models containing it
+
+    for gpu in tf.config.experimental.list_physical_devices("GPU"):
+        tf.config.experimental.set_memory_growth(gpu, True)

     mm = tf.keras.models.load_model(model_file, compile=False)
     return lambda imgs: mm((tf.cast(imgs, "float32") - 127.5) * 0.0078125).numpy()
@@ -96,7 +115,7 @@ def extract_IJB_data_11(data_path, subset, save_path=None, force_reload=False):
     if save_path == None:
         save_path = os.path.join(data_path, subset + "_backup.npz")
     if not force_reload and os.path.exists(save_path):
-        print(">>>> Reloading from backup: %s ..." % save_path)
+        print(">>>> Reload from backup: %s ..." % save_path)
         aa = np.load(save_path)
         return (
             aa["templates"],
@@ -123,15 +142,16 @@ def extract_IJB_data_11(data_path, subset, save_path=None, force_reload=False):
     print(">>>> Loading templates and medias...")
     templates, medias = read_IJB_meta_columns_to_int(media_list_path, columns=[1, 2])  # ['1.jpg', '1', '69544']
     print("templates: %s, medias: %s, unique templates: %s" % (templates.shape, medias.shape, np.unique(templates).shape))
-    # (227630,) (227630,) (12115,)
+    # templates: (227630,), medias: (227630,), unique templates: (12115,)

     print(">>>> Loading pairs...")
     p1, p2, label = read_IJB_meta_columns_to_int(pair_list_path, columns=[0, 1, 2])  # ['1', '11065', '1']
     print("p1: %s, unique p1: %s" % (p1.shape, np.unique(p1).shape))
     print("p2: %s, unique p2: %s" % (p2.shape, np.unique(p2).shape))
     print("label: %s, label value counts: %s" % (label.shape, dict(zip(*np.unique(label, return_counts=True)))))
-    # (8010270,) (8010270,) (8010270,) (1845,) (10270,) # 10270 + 1845 = 12115
-    # {0: 8000000, 1: 10270}
+    # p1: (8010270,), unique p1: (1845,)
+    # p2: (8010270,), unique p2: (10270,) # 10270 + 1845 = 12115 --> np.unique(templates).shape
+    # label: (8010270,), label value counts: {0: 8000000, 1: 10270}

     print(">>>> Loading images...")
     with open(img_list_path, "r") as ff:
@@ -142,7 +162,7 @@ def extract_IJB_data_11(data_path, subset, save_path=None, force_reload=False):
     landmarks = img_records[:, 1:-1].astype("float32").reshape(-1, 5, 2)
     face_scores = img_records[:, -1].astype("float32")
     print("img_names: %s, landmarks: %s, face_scores: %s" % (img_names.shape, landmarks.shape, face_scores.shape))
-    # (227630,) (227630, 5, 2) (227630,)
+    # img_names: (227630,), landmarks: (227630, 5, 2), face_scores: (227630,)
     print("face_scores value counts:", dict(zip(*np.histogram(face_scores, bins=9)[::-1])))
     # {0.1: 2515, 0.2: 0, 0.3: 62, 0.4: 94, 0.5: 136, 0.6: 197, 0.7: 291, 0.8: 538, 0.9: 223797}
@@ -166,11 +186,13 @@ def extract_gallery_prob_data(data_path, subset, save_path=None, force_reload=Fa
     if save_path == None:
         save_path = os.path.join(data_path, subset + "_gallery_prob_backup.npz")
     if not force_reload and os.path.exists(save_path):
-        print(">>>> Reloading from backup: %s ..." % save_path)
+        print(">>>> Reload from backup: %s ..." % save_path)
         aa = np.load(save_path)
         return (
-            aa["gallery_templates"],
-            aa["gallery_subject_ids"],
+            aa["s1_templates"],
+            aa["s1_subject_ids"],
+            aa["s2_templates"],
+            aa["s2_subject_ids"],
             aa["probe_mixed_templates"],
             aa["probe_mixed_subject_ids"],
         )
@@ -189,14 +211,8 @@ def extract_gallery_prob_data(data_path, subset, save_path=None, force_reload=Fa
     print(">>>> Loading gallery feature...")
     s1_templates, s1_subject_ids = read_IJB_meta_columns_to_int(gallery_s1_record, columns=[0, 1], skiprows=1, sep=",")
     s2_templates, s2_subject_ids = read_IJB_meta_columns_to_int(gallery_s2_record, columns=[0, 1], skiprows=1, sep=",")
-    gallery_templates = np.concatenate([s1_templates, s2_templates])
-    gallery_subject_ids = np.concatenate([s1_subject_ids, s2_subject_ids])
     print("s1 gallery: %s, ids: %s, unique: %s" % (s1_templates.shape, s1_subject_ids.shape, np.unique(s1_templates).shape))
     print("s2 gallery: %s, ids: %s, unique: %s" % (s2_templates.shape, s2_subject_ids.shape, np.unique(s2_templates).shape))
-    print(
-        "total gallery: %s, ids: %s, unique: %s"
-        % (gallery_templates.shape, gallery_subject_ids.shape, np.unique(gallery_templates).shape)
-    )

     print(">>>> Loading prope feature...")
     probe_mixed_templates, probe_mixed_subject_ids = read_IJB_meta_columns_to_int(
@@ -208,13 +224,15 @@ def extract_gallery_prob_data(data_path, subset, save_path=None, force_reload=Fa
         print(">>>> Saving backup to: %s ..." % save_path)
         np.savez(
             save_path,
-            gallery_templates=gallery_templates,
-            gallery_subject_ids=gallery_subject_ids,
+            s1_templates=s1_templates,
+            s1_subject_ids=s1_subject_ids,
+            s2_templates=s2_templates,
+            s2_subject_ids=s2_subject_ids,
             probe_mixed_templates=probe_mixed_templates,
             probe_mixed_subject_ids=probe_mixed_subject_ids,
         )
     print()
-    return gallery_templates, gallery_subject_ids, probe_mixed_templates, probe_mixed_subject_ids
+    return s1_templates, s1_subject_ids, s2_templates, s2_subject_ids, probe_mixed_templates, probe_mixed_subject_ids


 def get_embeddings(model_interf, img_names, landmarks, batch_size=64, flip=True):
@@ -231,6 +249,7 @@ def get_embeddings(model_interf, img_names, landmarks, batch_size=64, flip=True)
 def process_embeddings(embs, embs_f=[], use_flip_test=True, use_norm_score=False, use_detector_score=True, face_scores=None):
+    print(">>>> process_embeddings: Norm {}, Detect_score {}, Flip {}".format(use_norm_score, use_detector_score, use_flip_test))
     if use_flip_test and len(embs_f) != 0:
         embs = embs + embs_f
     if use_norm_score:
@@ -241,10 +260,10 @@ def process_embeddings(embs, embs_f=[], use_flip_test=True, use_norm_score=False
 def image2template_feature(img_feats=None, templates=None, medias=None, choose_templates=None, choose_ids=None):
-    if choose_templates is not None:  # 1N
+    if choose_templates is not None:  # 1:N
         unique_templates, indices = np.unique(choose_templates, return_index=True)
         unique_subjectids = choose_ids[indices]
-    else:  # 11
+    else:  # 1:1
         unique_templates = np.unique(templates)
         unique_subjectids = None
@@ -269,73 +288,65 @@ def image2template_feature(img_feats=None, templates=None, medias=None, choose_t
     return template_norm_feats, unique_templates, unique_subjectids


-def verification_11(template_norm_feats=None, unique_templates=None, p1=None, p2=None, batch_size=100000):
+def verification_11(template_norm_feats=None, unique_templates=None, p1=None, p2=None, batch_size=10000):
+    try:
+        print(">>>> Trying cupy.")
+        import cupy as cp
+
+        template_norm_feats = cp.array(template_norm_feats)
+        score_func = lambda feat1, feat2: cp.sum(feat1 * feat2, axis=-1).get()
+        test = score_func(template_norm_feats[:batch_size], template_norm_feats[:batch_size])  # warm-up, also verifies GPU memory
+    except:  # fall back to numpy if cupy is missing or GPU allocation fails
+        score_func = lambda feat1, feat2: np.sum(feat1 * feat2, -1)
+
     template2id = np.zeros(max(unique_templates) + 1, dtype=int)
-    for count_template, uqt in enumerate(unique_templates):
-        template2id[uqt] = count_template
+    template2id[unique_templates] = np.arange(len(unique_templates))  # vectorized template-id -> row-index lookup

     steps = int(np.ceil(len(p1) / batch_size))
     score = []
     for id in tqdm(range(steps), "Verification"):
         feat1 = template_norm_feats[template2id[p1[id * batch_size : (id + 1) * batch_size]]]
         feat2 = template_norm_feats[template2id[p2[id * batch_size : (id + 1) * batch_size]]]
-        score.extend(np.sum(feat1 * feat2, -1))
+        score.extend(score_func(feat1, feat2))
     return np.array(score)


-def evaluation_1N(query_feats, gallery_feats, query_ids, reg_ids):
-    import heapq
-
-    Fars = [0.01, 0.1]
+def evaluation_1N(query_feats, gallery_feats, query_ids, reg_ids, fars=[0.01, 0.1]):
     print("query_feats: %s, gallery_feats: %s" % (query_feats.shape, gallery_feats.shape))
+    similarity = np.dot(query_feats, gallery_feats.T)  # (19593, 3531)

-    query_num = query_feats.shape[0]
-    gallery_num = gallery_feats.shape[0]
+    top_1_count, top_5_count, top_10_count = 0, 0, 0
+    pos_sims, neg_sims, non_gallery_sims = [], [], []
+    for index, query_id in enumerate(query_ids):
+        if query_id in reg_ids:
+            gallery_label = np.argwhere(reg_ids == query_id)[0, 0]
+            index_sorted = np.argsort(similarity[index])[::-1]

-    similarity = np.dot(query_feats, gallery_feats.T)
-    print("similarity shape:", similarity.shape)
-    top_inds = np.argsort(-similarity)
-    print("top_inds shape:", top_inds.shape)
+            top_1_count += gallery_label in index_sorted[:1]
+            top_5_count += gallery_label in index_sorted[:5]
+            top_10_count += gallery_label in index_sorted[:10]

-    # gen_mask
-    mask = []
-    for query_id in query_ids:
-        pos = [i for i, x in enumerate(reg_ids) if query_id == x]
-        if len(pos) != 1:
-            raise RuntimeError("RegIdsError with id = {}, duplicate = {} ".format(query_id, len(pos)))
-        mask.append(pos[0])
+            pos_sims.append(similarity[index][reg_ids == query_id][0])
+            neg_sims.append(similarity[index][reg_ids != query_id])
+        else:
+            non_gallery_sims.append(similarity[index])
+    total_pos = len(pos_sims)
+    pos_sims, neg_sims, non_gallery_sims = np.array(pos_sims), np.array(neg_sims), np.array(non_gallery_sims)
+    print("pos_sims: %s, neg_sims: %s, non_gallery_sims: %s" % (pos_sims.shape, neg_sims.shape, non_gallery_sims.shape))
+    print("top1: %f, top5: %f, top10: %f" % (top_1_count / total_pos, top_5_count / total_pos, top_10_count / total_pos))

-    # calculate top_n
-    correct_num_1, correct_num_5, correct_num_10 = 0, 0, 0
-    for i in range(query_num):
-        top_1, top_5, top_10 = top_inds[i, 0], top_inds[i, 0:5], top_inds[i, 0:10]
-        if mask[i] == top_1:
-            correct_num_1 += 1
-        if mask[i] in top_5:
-            correct_num_5 += 1
-        if mask[i] in top_10:
-            correct_num_10 += 1
-    print("top1: %f, top5: %f, top10: %f" % (correct_num_1 / query_num, correct_num_5 / query_num, correct_num_10 / query_num))
-
-    # neg_pair_num = query_num * gallery_num - query_num
-    # print("neg_pair_num:", neg_pair_num)
-    required_topk = [int(np.ceil(query_num * x)) for x in Fars]
-    top_sims = similarity
-    # calculate fars and tprs
-    pos_sims = []
-    for i in range(query_num):
-        gt = mask[i]
-        pos_sims.append(top_sims[i, gt])
-        top_sims[i, gt] = -2.0
-
-    pos_sims = np.array(pos_sims)
-    neg_sims = top_sims[np.where(top_sims > -2.0)]
-    neg_sims_sorted = heapq.nlargest(max(required_topk), neg_sims)  # heap sort
print("pos_sims: %s, neg_sims: %s, neg_sims_sorted: %d" % (pos_sims.shape, neg_sims.shape, len(neg_sims_sorted))) - for far, pos in zip(Fars, required_topk): - th = neg_sims_sorted[pos - 1] - recall = np.sum(pos_sims > th) / query_num - print("far = {:.10f} pr = {:.10f} th = {:.10f}".format(far, recall, th)) + correct_pos_cond = pos_sims > neg_sims.max(1) + non_gallery_sims_sorted = np.sort(non_gallery_sims.max(1))[::-1] + threshes, recalls = [], [] + for far in fars: + # thresh = non_gallery_sims_sorted[int(np.ceil(non_gallery_sims_sorted.shape[0] * far)) - 1] + thresh = non_gallery_sims_sorted[max(int((non_gallery_sims_sorted.shape[0]) * far) - 1, 0)] + recall = np.logical_and(correct_pos_cond, pos_sims > thresh).sum() / pos_sims.shape[0] + threshes.append(thresh) + recalls.append(recall) + # print("FAR = {:.10f} TPIR = {:.10f} th = {:.10f}".format(far, recall, thresh)) + cmc_scores = list(zip(neg_sims, pos_sims.reshape(-1, 1))) + list(zip(non_gallery_sims, [None] * non_gallery_sims.shape[0])) + return top_1_count, top_5_count, top_10_count, threshes, recalls, cmc_scores class IJB_test: @@ -348,6 +359,8 @@ class IJB_test: interf_func = keras_model_interf(model_file) elif model_file.endswith(".pth") or model_file.endswith(".pt"): interf_func = Torch_model_interf(model_file) + elif model_file.endswith(".onnx") or model_file.endswith(".ONNX"): + interf_func = ONNX_model_interf(model_file) else: interf_func = Mxnet_model_interf(model_file) self.embs, self.embs_f = get_embeddings(interf_func, img_names, landmarks, batch_size=batch_size) @@ -388,8 +401,11 @@ class IJB_test: scores.append(self.run_model_test_single(use_flip_test, use_norm_score, use_detector_score)) return scores, names - def run_model_test_1N(self): - gallery_templates, gallery_subject_ids, probe_mixed_templates, probe_mixed_subject_ids = extract_gallery_prob_data( + def run_model_test_1N(self, npoints=100): + fars_cal = [10 ** ii for ii in np.arange(-4, 0, 4 / npoints)] + [1] # plot in range [10-4, 1] + fars_show_idx = np.arange(len(fars_cal))[:: npoints // 4] # npoints=100, fars_show=[0.0001, 0.001, 0.01, 0.1, 1.0] + + g1_templates, g1_ids, g2_templates, g2_ids, probe_mixed_templates, probe_mixed_ids = extract_gallery_prob_data( self.data_path, self.subset, force_reload=self.force_reload ) img_input_feats = process_embeddings( @@ -400,21 +416,48 @@ class IJB_test: use_detector_score=True, face_scores=self.face_scores, ) - gallery_templates_feature, gallery_unique_templates, gallery_unique_subject_ids = image2template_feature( - img_input_feats, self.templates, self.medias, gallery_templates, gallery_subject_ids + g1_templates_feature, g1_unique_templates, g1_unique_ids = image2template_feature( + img_input_feats, self.templates, self.medias, g1_templates, g1_ids + ) + g2_templates_feature, g2_unique_templates, g2_unique_ids = image2template_feature( + img_input_feats, self.templates, self.medias, g2_templates, g2_ids ) - print("gallery_templates_feature:", gallery_templates_feature.shape) - print("gallery_unique_subject_ids:", gallery_unique_subject_ids.shape) - probe_mixed_templates_feature, probe_mixed_unique_templates, probe_mixed_unique_subject_ids = image2template_feature( - img_input_feats, self.templates, self.medias, probe_mixed_templates, probe_mixed_subject_ids + img_input_feats, self.templates, self.medias, probe_mixed_templates, probe_mixed_ids ) - print("probe_mixed_templates_feature:", probe_mixed_templates_feature.shape) - print("probe_mixed_unique_subject_ids:", probe_mixed_unique_subject_ids.shape) + 
print("g1_templates_feature:", g1_templates_feature.shape) # (1772, 512) + print("g2_templates_feature:", g2_templates_feature.shape) # (1759, 512) - evaluation_1N( - probe_mixed_templates_feature, gallery_templates_feature, probe_mixed_unique_subject_ids, gallery_unique_subject_ids + print("probe_mixed_templates_feature:", probe_mixed_templates_feature.shape) # (19593, 512) + print("probe_mixed_unique_subject_ids:", probe_mixed_unique_subject_ids.shape) # (19593,) + + print(">>>> Gallery 1") + g1_top_1_count, g1_top_5_count, g1_top_10_count, g1_threshes, g1_recalls, g1_cmc_scores = evaluation_1N( + probe_mixed_templates_feature, g1_templates_feature, probe_mixed_unique_subject_ids, g1_unique_ids, fars_cal ) + print(">>>> Gallery 2") + g2_top_1_count, g2_top_5_count, g2_top_10_count, g2_threshes, g2_recalls, g2_cmc_scores = evaluation_1N( + probe_mixed_templates_feature, g2_templates_feature, probe_mixed_unique_subject_ids, g2_unique_ids, fars_cal + ) + print(">>>> Mean") + query_num = probe_mixed_templates_feature.shape[0] + top_1 = (g1_top_1_count + g2_top_1_count) / query_num + top_5 = (g1_top_5_count + g2_top_5_count) / query_num + top_10 = (g1_top_10_count + g2_top_10_count) / query_num + print("[Mean] top1: %f, top5: %f, top10: %f" % (top_1, top_5, top_10)) + + mean_tpirs = (np.array(g1_recalls) + np.array(g2_recalls)) / 2 + show_result = {} + for id, far in enumerate(fars_cal): + if id in fars_show_idx: + show_result.setdefault("far", []).append(far) + show_result.setdefault("g1_tpir", []).append(g1_recalls[id]) + show_result.setdefault("g1_thresh", []).append(g1_threshes[id]) + show_result.setdefault("g2_tpir", []).append(g2_recalls[id]) + show_result.setdefault("g2_thresh", []).append(g2_threshes[id]) + show_result.setdefault("mean_tpir", []).append(mean_tpirs[id]) + print(pd.DataFrame(show_result).set_index("far").to_markdown()) + return fars_cal, mean_tpirs, g1_cmc_scores, g2_cmc_scores def plot_roc_and_calculate_tpr(scores, names=None, label=None): @@ -434,7 +477,7 @@ def plot_roc_and_calculate_tpr(scores, names=None, label=None): score_dict[name] = np.load(score) elif isinstance(score, str) and score.endswith(".txt"): # IJB meta data like ijbb_template_pair_label.txt - label = pd.read_csv(score, sep=" ").values[:, 2] + label = pd.read_csv(score, sep=" ", header=None).values[:, 2] else: name = name if name is not None else str(id) score_dict[name] = score @@ -451,6 +494,7 @@ def plot_roc_and_calculate_tpr(scores, names=None, label=None): tpr_result[name] = [tpr[np.argmin(abs(fpr - ii))] for ii in x_labels] fpr_dict[name], tpr_dict[name], roc_auc_dict[name] = fpr, tpr, roc_auc tpr_result_df = pd.DataFrame(tpr_result, index=x_labels).T + tpr_result_df['AUC'] = pd.Series(roc_auc_dict) tpr_result_df.columns.name = "Methods" print(tpr_result_df.to_markdown()) # print(tpr_result_df) @@ -461,35 +505,71 @@ def plot_roc_and_calculate_tpr(scores, names=None, label=None): fig = plt.figure() for name in score_dict: plt.plot(fpr_dict[name], tpr_dict[name], lw=1, label="[%s (AUC = %0.4f%%)]" % (name, roc_auc_dict[name] * 100)) + title = "ROC on IJB" + name.split("IJB")[-1][0] if "IJB" in name else "ROC on IJB" plt.xlim([10 ** -6, 0.1]) - plt.ylim([0.3, 1.0]) - plt.grid(linestyle="--", linewidth=1) - plt.xticks(x_labels) - plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) plt.xscale("log") + plt.xticks(x_labels) plt.xlabel("False Positive Rate") + plt.ylim([0.3, 1.0]) + plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) plt.ylabel("True Positive Rate") - plt.title("ROC on IJB") - 
plt.legend(loc="lower right") + + plt.grid(linestyle="--", linewidth=1) + plt.title(title) + plt.legend(loc="lower right", fontsize='x-small') plt.tight_layout() plt.show() except: - print("Missing matplotlib") + print("matplotlib plot failed") fig = None return tpr_result_df, fig +def plot_dir_far_cmc_scores(scores, names=None): + try: + import matplotlib.pyplot as plt + + fig = plt.figure() + for id, score in enumerate(scores): + name = None if names is None else names[id] + if isinstance(score, str) and score.endswith(".npz"): + aa = np.load(score) + score, name = aa.get("scores")[0], aa.get("names")[0] + fars, tpirs = score[0], score[1] + name = name if name is not None else str(id) + + auc_value = auc(fars, tpirs) + label = "[%s (AUC = %0.4f%%)]" % (name, auc_value * 100) + plt.plot(fars, tpirs, lw=1, label=label) + + plt.xlabel("False Alarm Rate") + plt.xlim([0.0001, 1]) + plt.xscale("log") + plt.ylabel("Detection & Identification Rate (%)") + plt.ylim([0, 1]) + + plt.grid(linestyle="--", linewidth=1) + plt.legend(fontsize='x-small') + plt.tight_layout() + plt.show() + except: + print("matplotlib plot failed") + fig = None + + return fig + + def parse_arguments(argv): import argparse - default_save_result_name = "IJB_result/{model_name}_{subset}.npz" + default_save_result_name = "IJB_result/{model_name}_{subset}_{type}.npz" parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("-m", "--model_file", type=str, default=None, help="Saved model file, could be keras h5 / pytorch jit pth / mxnet") - parser.add_argument("-d", "--data_path", type=str, default="./", help="Dataset path") - parser.add_argument("-s", "--subset", type=str, default="IJBB", help="Subset test target, could be IJBB / IJBC") - parser.add_argument("-b", "--batch_size", type=int, default=64, help="Batch size for get_embeddings") + parser.add_argument("-m", "--model_file", type=str, default=None, help="Saved model, keras h5 / pytorch jit pth / onnx / mxnet") + parser.add_argument("-d", "--data_path", type=str, default="./", help="Dataset path containing IJBB and IJBC sub folder") + parser.add_argument("-s", "--subset", type=str, default="IJBC", help="Subset test target, could be IJBB / IJBC") + parser.add_argument("-b", "--batch_size", type=int, default=128, help="Batch size for get_embeddings") parser.add_argument( "-R", "--save_result", type=str, default=default_save_result_name, help="Filename for saving / restore result" ) @@ -513,15 +593,16 @@ def parse_arguments(argv): print("Please provide -m MODEL_FILE, see `--help` for usage.") exit(1) elif args.model_file != None: - if args.model_file.endswith(".h5") or args.model_file.endswith(".pth") or args.model_file.endswith(".pt"): - # Keras model file "model.h5", pytorch model ends with `.pth` or `.pt` + if args.model_file.endswith(".h5") or args.model_file.endswith(".pth") or args.model_file.endswith(".pt") or args.model_file.endswith(".onnx"): + # Keras model file "model.h5", pytorch model ends with `.pth` or `.pt`, onnx model ends with `.onnx` model_name = os.path.splitext(os.path.basename(args.model_file))[0] else: # MXNet model file "models/r50-arcface-emore/model,1" model_name = os.path.basename(os.path.dirname(args.model_file)) if args.save_result == default_save_result_name: - args.save_result = default_save_result_name.format(model_name=model_name, subset=args.subset) + type = "1N" if args.is_one_2_N else "11" + args.save_result = default_save_result_name.format(model_name=model_name, subset=args.subset, 
type=type) return args @@ -530,32 +611,45 @@ if __name__ == "__main__": args = parse_arguments(sys.argv[1:]) if args.plot_only != None and len(args.plot_only) != 0: - plot_roc_and_calculate_tpr(args.plot_only) + if args.is_one_2_N: + plot_dir_far_cmc_scores(args.plot_only) + else: + plot_roc_and_calculate_tpr(args.plot_only) else: - save_name = os.path.splitext(args.save_result)[0] + save_name = os.path.splitext(os.path.basename(args.save_result))[0] save_items = {} + save_path = os.path.dirname(args.save_result) + if len(save_path) != 0 and not os.path.exists(save_path): + os.makedirs(save_path) + tt = IJB_test(args.model_file, args.data_path, args.subset, args.batch_size, args.force_reload, args.save_result) + if args.save_embeddings: # Save embeddings first, in case of any error happens later... + np.savez(args.save_result, embs=tt.embs, embs_f=tt.embs_f) + if args.is_one_2_N: # 1:N test - tt.run_model_test_1N() + fars, tpirs, _, _ = tt.run_model_test_1N() + scores = [(fars, tpirs)] + names = [save_name] + save_items.update({"scores": scores, "names": names}) elif args.is_bunch: # All 8 tests N{0,1}D{0,1}F{0,1} scores, names = tt.run_model_test_bunch() names = [save_name + "_" + ii for ii in names] + label = tt.label save_items.update({"scores": scores, "names": names}) else: # Basic 1:1 N0D1F1 test score = tt.run_model_test_single() - scores, names = [score], [save_name] + scores, names, label = [score], [save_name], tt.label save_items.update({"scores": scores, "names": names}) if args.save_embeddings: save_items.update({"embs": tt.embs, "embs_f": tt.embs_f}) if args.save_label: - save_items.update({"label": tt.label}) + save_items.update({"label": label}) if args.model_file != None or args.save_embeddings: # embeddings not restored from file or should save_embeddings again - save_path = os.path.dirname(args.save_result) - if not os.path.exists(save_path): - os.makedirs(save_path) np.savez(args.save_result, **save_items) - if not args.is_one_2_N: - plot_roc_and_calculate_tpr(scores, names=names, label=tt.label) + if args.is_one_2_N: + plot_dir_far_cmc_scores(scores=scores, names=names) + else: + plot_roc_and_calculate_tpr(scores, names=names, label=label)
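
---
Usage note (not part of the patch; the model file and dataset paths below are placeholders): with the new ONNX branch in IJB_test and the `{type}` field in the default result name, a 1:1 IJBC run could look like

    python ijb_evals.py -m ./checkpoints/model.onnx -d /data/IJB_release/ -s IJBC -b 128

which would save scores to IJB_result/model_IJBC_11.npz. The same command with the 1:N switch enabled (the `is_one_2_N` argument, whose flag is defined outside this diff) would write model_IJBC_1N.npz and print the FAR/TPIR markdown table from run_model_test_1N.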
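Side note on the cupy path in verification_11: the bare `except` means any cupy failure silently drops back to numpy. A quick standalone check of which backend will be taken, mirroring the same score function (a sketch, independent of this script):

    import numpy as np
    try:
        import cupy as cp
        print("cupy backend:", cp.sum(cp.ones((2, 4)) * 2, axis=-1).get())  # GPU path
    except Exception:
        print("numpy fallback:", np.sum(np.ones((2, 4)) * 2, -1))  # CPU path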
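Saved 1:N results can also be replotted without re-running the model; a minimal sketch assuming a previous run produced the hypothetical file IJB_result/model_IJBC_1N.npz:

    from ijb_evals import plot_dir_far_cmc_scores
    # the .npz branch reads aa["scores"][0] -> (fars, tpirs) and aa["names"][0]
    plot_dir_far_cmc_scores(["IJB_result/model_IJBC_1N.npz"])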