mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
Fix IJB 1:N evaluation in ijb_evals.py
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from skimage import transform
|
||||
from sklearn.preprocessing import normalize
|
||||
from sklearn.metrics import roc_curve, auc
|
||||
import pandas as pd
|
||||
import cv2
|
||||
|
||||
|
||||
class Mxnet_model_interf:
|
||||
@@ -62,8 +62,27 @@ class Torch_model_interf:
|
||||
return output.cpu().detach().numpy()
|
||||
|
||||
|
||||
class ONNX_model_interf:
|
||||
def __init__(self, model_file, image_size=(112, 112)):
|
||||
import onnxruntime as ort
|
||||
ort.set_default_logger_severity(3)
|
||||
self.ort_session = ort.InferenceSession(model_file)
|
||||
self.output_names = [self.ort_session.get_outputs()[0].name]
|
||||
self.input_name = self.ort_session.get_inputs()[0].name
|
||||
|
||||
def __call__(self, imgs):
|
||||
imgs = imgs.transpose(0, 3, 1, 2).astype("float32")
|
||||
imgs = (imgs - 127.5) * 0.0078125
|
||||
outputs = self.ort_session.run(self.output_names, {self.input_name: imgs})
|
||||
return outputs[0]
|
||||
|
||||
|
||||
def keras_model_interf(model_file):
|
||||
import tensorflow as tf
|
||||
from tensorflow_addons.layers import StochasticDepth
|
||||
|
||||
for gpu in tf.config.experimental.list_physical_devices("GPU"):
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
|
||||
mm = tf.keras.models.load_model(model_file, compile=False)
|
||||
return lambda imgs: mm((tf.cast(imgs, "float32") - 127.5) * 0.0078125).numpy()
|
||||
@@ -96,7 +115,7 @@ def extract_IJB_data_11(data_path, subset, save_path=None, force_reload=False):
|
||||
if save_path == None:
|
||||
save_path = os.path.join(data_path, subset + "_backup.npz")
|
||||
if not force_reload and os.path.exists(save_path):
|
||||
print(">>>> Reloading from backup: %s ..." % save_path)
|
||||
print(">>>> Reload from backup: %s ..." % save_path)
|
||||
aa = np.load(save_path)
|
||||
return (
|
||||
aa["templates"],
|
||||
@@ -123,15 +142,16 @@ def extract_IJB_data_11(data_path, subset, save_path=None, force_reload=False):
|
||||
print(">>>> Loading templates and medias...")
|
||||
templates, medias = read_IJB_meta_columns_to_int(media_list_path, columns=[1, 2]) # ['1.jpg', '1', '69544']
|
||||
print("templates: %s, medias: %s, unique templates: %s" % (templates.shape, medias.shape, np.unique(templates).shape))
|
||||
# (227630,) (227630,) (12115,)
|
||||
# templates: (227630,), medias: (227630,), unique templates: (12115,)
|
||||
|
||||
print(">>>> Loading pairs...")
|
||||
p1, p2, label = read_IJB_meta_columns_to_int(pair_list_path, columns=[0, 1, 2]) # ['1', '11065', '1']
|
||||
print("p1: %s, unique p1: %s" % (p1.shape, np.unique(p1).shape))
|
||||
print("p2: %s, unique p2: %s" % (p2.shape, np.unique(p2).shape))
|
||||
print("label: %s, label value counts: %s" % (label.shape, dict(zip(*np.unique(label, return_counts=True)))))
|
||||
# (8010270,) (8010270,) (8010270,) (1845,) (10270,) # 10270 + 1845 = 12115
|
||||
# {0: 8000000, 1: 10270}
|
||||
# p1: (8010270,), unique p1: (1845,)
|
||||
# p2: (8010270,), unique p2: (10270,) # 10270 + 1845 = 12115 --> np.unique(templates).shape
|
||||
# label: (8010270,), label value counts: {0: 8000000, 1: 10270}
|
||||
|
||||
print(">>>> Loading images...")
|
||||
with open(img_list_path, "r") as ff:
|
||||
@@ -142,7 +162,7 @@ def extract_IJB_data_11(data_path, subset, save_path=None, force_reload=False):
|
||||
landmarks = img_records[:, 1:-1].astype("float32").reshape(-1, 5, 2)
|
||||
face_scores = img_records[:, -1].astype("float32")
|
||||
print("img_names: %s, landmarks: %s, face_scores: %s" % (img_names.shape, landmarks.shape, face_scores.shape))
|
||||
# (227630,) (227630, 5, 2) (227630,)
|
||||
# img_names: (227630,), landmarks: (227630, 5, 2), face_scores: (227630,)
|
||||
print("face_scores value counts:", dict(zip(*np.histogram(face_scores, bins=9)[::-1])))
|
||||
# {0.1: 2515, 0.2: 0, 0.3: 62, 0.4: 94, 0.5: 136, 0.6: 197, 0.7: 291, 0.8: 538, 0.9: 223797}
|
||||
|
||||
@@ -166,11 +186,13 @@ def extract_gallery_prob_data(data_path, subset, save_path=None, force_reload=Fa
|
||||
if save_path == None:
|
||||
save_path = os.path.join(data_path, subset + "_gallery_prob_backup.npz")
|
||||
if not force_reload and os.path.exists(save_path):
|
||||
print(">>>> Reloading from backup: %s ..." % save_path)
|
||||
print(">>>> Reload from backup: %s ..." % save_path)
|
||||
aa = np.load(save_path)
|
||||
return (
|
||||
aa["gallery_templates"],
|
||||
aa["gallery_subject_ids"],
|
||||
aa["s1_templates"],
|
||||
aa["s1_subject_ids"],
|
||||
aa["s2_templates"],
|
||||
aa["s2_subject_ids"],
|
||||
aa["probe_mixed_templates"],
|
||||
aa["probe_mixed_subject_ids"],
|
||||
)
|
||||
@@ -189,14 +211,8 @@ def extract_gallery_prob_data(data_path, subset, save_path=None, force_reload=Fa
|
||||
print(">>>> Loading gallery feature...")
|
||||
s1_templates, s1_subject_ids = read_IJB_meta_columns_to_int(gallery_s1_record, columns=[0, 1], skiprows=1, sep=",")
|
||||
s2_templates, s2_subject_ids = read_IJB_meta_columns_to_int(gallery_s2_record, columns=[0, 1], skiprows=1, sep=",")
|
||||
gallery_templates = np.concatenate([s1_templates, s2_templates])
|
||||
gallery_subject_ids = np.concatenate([s1_subject_ids, s2_subject_ids])
|
||||
print("s1 gallery: %s, ids: %s, unique: %s" % (s1_templates.shape, s1_subject_ids.shape, np.unique(s1_templates).shape))
|
||||
print("s2 gallery: %s, ids: %s, unique: %s" % (s2_templates.shape, s2_subject_ids.shape, np.unique(s2_templates).shape))
|
||||
print(
|
||||
"total gallery: %s, ids: %s, unique: %s"
|
||||
% (gallery_templates.shape, gallery_subject_ids.shape, np.unique(gallery_templates).shape)
|
||||
)
|
||||
|
||||
print(">>>> Loading prope feature...")
|
||||
probe_mixed_templates, probe_mixed_subject_ids = read_IJB_meta_columns_to_int(
|
||||
@@ -208,13 +224,15 @@ def extract_gallery_prob_data(data_path, subset, save_path=None, force_reload=Fa
|
||||
print(">>>> Saving backup to: %s ..." % save_path)
|
||||
np.savez(
|
||||
save_path,
|
||||
gallery_templates=gallery_templates,
|
||||
gallery_subject_ids=gallery_subject_ids,
|
||||
s1_templates=s1_templates,
|
||||
s1_subject_ids=s1_subject_ids,
|
||||
s2_templates=s2_templates,
|
||||
s2_subject_ids=s2_subject_ids,
|
||||
probe_mixed_templates=probe_mixed_templates,
|
||||
probe_mixed_subject_ids=probe_mixed_subject_ids,
|
||||
)
|
||||
print()
|
||||
return gallery_templates, gallery_subject_ids, probe_mixed_templates, probe_mixed_subject_ids
|
||||
return s1_templates, s1_subject_ids, s2_templates, s2_subject_ids, probe_mixed_templates, probe_mixed_subject_ids
|
||||
|
||||
|
||||
def get_embeddings(model_interf, img_names, landmarks, batch_size=64, flip=True):
|
||||
@@ -231,6 +249,7 @@ def get_embeddings(model_interf, img_names, landmarks, batch_size=64, flip=True)
|
||||
|
||||
|
||||
def process_embeddings(embs, embs_f=[], use_flip_test=True, use_norm_score=False, use_detector_score=True, face_scores=None):
|
||||
print(">>>> process_embeddings: Norm {}, Detect_score {}, Flip {}".format(use_norm_score, use_detector_score, use_flip_test))
|
||||
if use_flip_test and len(embs_f) != 0:
|
||||
embs = embs + embs_f
|
||||
if use_norm_score:
|
||||
@@ -241,10 +260,10 @@ def process_embeddings(embs, embs_f=[], use_flip_test=True, use_norm_score=False
|
||||
|
||||
|
||||
def image2template_feature(img_feats=None, templates=None, medias=None, choose_templates=None, choose_ids=None):
|
||||
if choose_templates is not None: # 1N
|
||||
if choose_templates is not None: # 1:N
|
||||
unique_templates, indices = np.unique(choose_templates, return_index=True)
|
||||
unique_subjectids = choose_ids[indices]
|
||||
else: # 11
|
||||
else: # 1:1
|
||||
unique_templates = np.unique(templates)
|
||||
unique_subjectids = None
|
||||
|
||||
@@ -269,73 +288,65 @@ def image2template_feature(img_feats=None, templates=None, medias=None, choose_t
|
||||
return template_norm_feats, unique_templates, unique_subjectids
|
||||
|
||||
|
||||
def verification_11(template_norm_feats=None, unique_templates=None, p1=None, p2=None, batch_size=100000):
|
||||
def verification_11(template_norm_feats=None, unique_templates=None, p1=None, p2=None, batch_size=10000):
|
||||
try:
|
||||
print(">>>> Trying cupy.")
|
||||
import cupy as cp
|
||||
|
||||
template_norm_feats = cp.array(template_norm_feats)
|
||||
score_func = lambda feat1, feat2: cp.sum(feat1 * feat2, axis=-1).get()
|
||||
test = score_func(template_norm_feats[:batch_size], template_norm_feats[:batch_size])
|
||||
except:
|
||||
score_func = lambda feat1, feat2: np.sum(feat1 * feat2, -1)
|
||||
|
||||
template2id = np.zeros(max(unique_templates) + 1, dtype=int)
|
||||
for count_template, uqt in enumerate(unique_templates):
|
||||
template2id[uqt] = count_template
|
||||
template2id[unique_templates] = np.arange(len(unique_templates))
|
||||
|
||||
steps = int(np.ceil(len(p1) / batch_size))
|
||||
score = []
|
||||
for id in tqdm(range(steps), "Verification"):
|
||||
feat1 = template_norm_feats[template2id[p1[id * batch_size : (id + 1) * batch_size]]]
|
||||
feat2 = template_norm_feats[template2id[p2[id * batch_size : (id + 1) * batch_size]]]
|
||||
score.extend(np.sum(feat1 * feat2, -1))
|
||||
score.extend(score_func(feat1, feat2))
|
||||
return np.array(score)
|
||||
|
||||
|
||||
def evaluation_1N(query_feats, gallery_feats, query_ids, reg_ids):
|
||||
import heapq
|
||||
|
||||
Fars = [0.01, 0.1]
|
||||
def evaluation_1N(query_feats, gallery_feats, query_ids, reg_ids, fars=[0.01, 0.1]):
|
||||
print("query_feats: %s, gallery_feats: %s" % (query_feats.shape, gallery_feats.shape))
|
||||
similarity = np.dot(query_feats, gallery_feats.T) # (19593, 3531)
|
||||
|
||||
query_num = query_feats.shape[0]
|
||||
gallery_num = gallery_feats.shape[0]
|
||||
top_1_count, top_5_count, top_10_count = 0, 0, 0
|
||||
pos_sims, neg_sims, non_gallery_sims = [], [], []
|
||||
for index, query_id in enumerate(query_ids):
|
||||
if query_id in reg_ids:
|
||||
gallery_label = np.argwhere(reg_ids == query_id)[0, 0]
|
||||
index_sorted = np.argsort(similarity[index])[::-1]
|
||||
|
||||
similarity = np.dot(query_feats, gallery_feats.T)
|
||||
print("similarity shape:", similarity.shape)
|
||||
top_inds = np.argsort(-similarity)
|
||||
print("top_inds shape:", top_inds.shape)
|
||||
top_1_count += gallery_label in index_sorted[:1]
|
||||
top_5_count += gallery_label in index_sorted[:5]
|
||||
top_10_count += gallery_label in index_sorted[:10]
|
||||
|
||||
# gen_mask
|
||||
mask = []
|
||||
for query_id in query_ids:
|
||||
pos = [i for i, x in enumerate(reg_ids) if query_id == x]
|
||||
if len(pos) != 1:
|
||||
raise RuntimeError("RegIdsError with id = {}, duplicate = {} ".format(query_id, len(pos)))
|
||||
mask.append(pos[0])
|
||||
pos_sims.append(similarity[index][reg_ids == query_id][0])
|
||||
neg_sims.append(similarity[index][reg_ids != query_id])
|
||||
else:
|
||||
non_gallery_sims.append(similarity[index])
|
||||
total_pos = len(pos_sims)
|
||||
pos_sims, neg_sims, non_gallery_sims = np.array(pos_sims), np.array(neg_sims), np.array(non_gallery_sims)
|
||||
print("pos_sims: %s, neg_sims: %s, non_gallery_sims: %s" % (pos_sims.shape, neg_sims.shape, non_gallery_sims.shape))
|
||||
print("top1: %f, top5: %f, top10: %f" % (top_1_count / total_pos, top_5_count / total_pos, top_10_count / total_pos))
|
||||
|
||||
# calculate top_n
|
||||
correct_num_1, correct_num_5, correct_num_10 = 0, 0, 0
|
||||
for i in range(query_num):
|
||||
top_1, top_5, top_10 = top_inds[i, 0], top_inds[i, 0:5], top_inds[i, 0:10]
|
||||
if mask[i] == top_1:
|
||||
correct_num_1 += 1
|
||||
if mask[i] in top_5:
|
||||
correct_num_5 += 1
|
||||
if mask[i] in top_10:
|
||||
correct_num_10 += 1
|
||||
print("top1: %f, top5: %f, top10: %f" % (correct_num_1 / query_num, correct_num_5 / query_num, correct_num_10 / query_num))
|
||||
|
||||
# neg_pair_num = query_num * gallery_num - query_num
|
||||
# print("neg_pair_num:", neg_pair_num)
|
||||
required_topk = [int(np.ceil(query_num * x)) for x in Fars]
|
||||
top_sims = similarity
|
||||
# calculate fars and tprs
|
||||
pos_sims = []
|
||||
for i in range(query_num):
|
||||
gt = mask[i]
|
||||
pos_sims.append(top_sims[i, gt])
|
||||
top_sims[i, gt] = -2.0
|
||||
|
||||
pos_sims = np.array(pos_sims)
|
||||
neg_sims = top_sims[np.where(top_sims > -2.0)]
|
||||
neg_sims_sorted = heapq.nlargest(max(required_topk), neg_sims) # heap sort
|
||||
print("pos_sims: %s, neg_sims: %s, neg_sims_sorted: %d" % (pos_sims.shape, neg_sims.shape, len(neg_sims_sorted)))
|
||||
for far, pos in zip(Fars, required_topk):
|
||||
th = neg_sims_sorted[pos - 1]
|
||||
recall = np.sum(pos_sims > th) / query_num
|
||||
print("far = {:.10f} pr = {:.10f} th = {:.10f}".format(far, recall, th))
|
||||
correct_pos_cond = pos_sims > neg_sims.max(1)
|
||||
non_gallery_sims_sorted = np.sort(non_gallery_sims.max(1))[::-1]
|
||||
threshes, recalls = [], []
|
||||
for far in fars:
|
||||
# thresh = non_gallery_sims_sorted[int(np.ceil(non_gallery_sims_sorted.shape[0] * far)) - 1]
|
||||
thresh = non_gallery_sims_sorted[max(int((non_gallery_sims_sorted.shape[0]) * far) - 1, 0)]
|
||||
recall = np.logical_and(correct_pos_cond, pos_sims > thresh).sum() / pos_sims.shape[0]
|
||||
threshes.append(thresh)
|
||||
recalls.append(recall)
|
||||
# print("FAR = {:.10f} TPIR = {:.10f} th = {:.10f}".format(far, recall, thresh))
|
||||
cmc_scores = list(zip(neg_sims, pos_sims.reshape(-1, 1))) + list(zip(non_gallery_sims, [None] * non_gallery_sims.shape[0]))
|
||||
return top_1_count, top_5_count, top_10_count, threshes, recalls, cmc_scores
|
||||
|
||||
|
||||
class IJB_test:
|
||||
@@ -348,6 +359,8 @@ class IJB_test:
|
||||
interf_func = keras_model_interf(model_file)
|
||||
elif model_file.endswith(".pth") or model_file.endswith(".pt"):
|
||||
interf_func = Torch_model_interf(model_file)
|
||||
elif model_file.endswith(".onnx") or model_file.endswith(".ONNX"):
|
||||
interf_func = ONNX_model_interf(model_file)
|
||||
else:
|
||||
interf_func = Mxnet_model_interf(model_file)
|
||||
self.embs, self.embs_f = get_embeddings(interf_func, img_names, landmarks, batch_size=batch_size)
|
||||
@@ -388,8 +401,11 @@ class IJB_test:
|
||||
scores.append(self.run_model_test_single(use_flip_test, use_norm_score, use_detector_score))
|
||||
return scores, names
|
||||
|
||||
def run_model_test_1N(self):
|
||||
gallery_templates, gallery_subject_ids, probe_mixed_templates, probe_mixed_subject_ids = extract_gallery_prob_data(
|
||||
def run_model_test_1N(self, npoints=100):
|
||||
fars_cal = [10 ** ii for ii in np.arange(-4, 0, 4 / npoints)] + [1] # plot in range [10-4, 1]
|
||||
fars_show_idx = np.arange(len(fars_cal))[:: npoints // 4] # npoints=100, fars_show=[0.0001, 0.001, 0.01, 0.1, 1.0]
|
||||
|
||||
g1_templates, g1_ids, g2_templates, g2_ids, probe_mixed_templates, probe_mixed_ids = extract_gallery_prob_data(
|
||||
self.data_path, self.subset, force_reload=self.force_reload
|
||||
)
|
||||
img_input_feats = process_embeddings(
|
||||
@@ -400,21 +416,48 @@ class IJB_test:
|
||||
use_detector_score=True,
|
||||
face_scores=self.face_scores,
|
||||
)
|
||||
gallery_templates_feature, gallery_unique_templates, gallery_unique_subject_ids = image2template_feature(
|
||||
img_input_feats, self.templates, self.medias, gallery_templates, gallery_subject_ids
|
||||
g1_templates_feature, g1_unique_templates, g1_unique_ids = image2template_feature(
|
||||
img_input_feats, self.templates, self.medias, g1_templates, g1_ids
|
||||
)
|
||||
g2_templates_feature, g2_unique_templates, g2_unique_ids = image2template_feature(
|
||||
img_input_feats, self.templates, self.medias, g2_templates, g2_ids
|
||||
)
|
||||
print("gallery_templates_feature:", gallery_templates_feature.shape)
|
||||
print("gallery_unique_subject_ids:", gallery_unique_subject_ids.shape)
|
||||
|
||||
probe_mixed_templates_feature, probe_mixed_unique_templates, probe_mixed_unique_subject_ids = image2template_feature(
|
||||
img_input_feats, self.templates, self.medias, probe_mixed_templates, probe_mixed_subject_ids
|
||||
img_input_feats, self.templates, self.medias, probe_mixed_templates, probe_mixed_ids
|
||||
)
|
||||
print("probe_mixed_templates_feature:", probe_mixed_templates_feature.shape)
|
||||
print("probe_mixed_unique_subject_ids:", probe_mixed_unique_subject_ids.shape)
|
||||
print("g1_templates_feature:", g1_templates_feature.shape) # (1772, 512)
|
||||
print("g2_templates_feature:", g2_templates_feature.shape) # (1759, 512)
|
||||
|
||||
evaluation_1N(
|
||||
probe_mixed_templates_feature, gallery_templates_feature, probe_mixed_unique_subject_ids, gallery_unique_subject_ids
|
||||
print("probe_mixed_templates_feature:", probe_mixed_templates_feature.shape) # (19593, 512)
|
||||
print("probe_mixed_unique_subject_ids:", probe_mixed_unique_subject_ids.shape) # (19593,)
|
||||
|
||||
print(">>>> Gallery 1")
|
||||
g1_top_1_count, g1_top_5_count, g1_top_10_count, g1_threshes, g1_recalls, g1_cmc_scores = evaluation_1N(
|
||||
probe_mixed_templates_feature, g1_templates_feature, probe_mixed_unique_subject_ids, g1_unique_ids, fars_cal
|
||||
)
|
||||
print(">>>> Gallery 2")
|
||||
g2_top_1_count, g2_top_5_count, g2_top_10_count, g2_threshes, g2_recalls, g2_cmc_scores = evaluation_1N(
|
||||
probe_mixed_templates_feature, g2_templates_feature, probe_mixed_unique_subject_ids, g2_unique_ids, fars_cal
|
||||
)
|
||||
print(">>>> Mean")
|
||||
query_num = probe_mixed_templates_feature.shape[0]
|
||||
top_1 = (g1_top_1_count + g2_top_1_count) / query_num
|
||||
top_5 = (g1_top_5_count + g2_top_5_count) / query_num
|
||||
top_10 = (g1_top_10_count + g2_top_10_count) / query_num
|
||||
print("[Mean] top1: %f, top5: %f, top10: %f" % (top_1, top_5, top_10))
|
||||
|
||||
mean_tpirs = (np.array(g1_recalls) + np.array(g2_recalls)) / 2
|
||||
show_result = {}
|
||||
for id, far in enumerate(fars_cal):
|
||||
if id in fars_show_idx:
|
||||
show_result.setdefault("far", []).append(far)
|
||||
show_result.setdefault("g1_tpir", []).append(g1_recalls[id])
|
||||
show_result.setdefault("g1_thresh", []).append(g1_threshes[id])
|
||||
show_result.setdefault("g2_tpir", []).append(g2_recalls[id])
|
||||
show_result.setdefault("g2_thresh", []).append(g2_threshes[id])
|
||||
show_result.setdefault("mean_tpir", []).append(mean_tpirs[id])
|
||||
print(pd.DataFrame(show_result).set_index("far").to_markdown())
|
||||
return fars_cal, mean_tpirs, g1_cmc_scores, g2_cmc_scores
|
||||
|
||||
|
||||
def plot_roc_and_calculate_tpr(scores, names=None, label=None):
|
||||
@@ -434,7 +477,7 @@ def plot_roc_and_calculate_tpr(scores, names=None, label=None):
|
||||
score_dict[name] = np.load(score)
|
||||
elif isinstance(score, str) and score.endswith(".txt"):
|
||||
# IJB meta data like ijbb_template_pair_label.txt
|
||||
label = pd.read_csv(score, sep=" ").values[:, 2]
|
||||
label = pd.read_csv(score, sep=" ", header=None).values[:, 2]
|
||||
else:
|
||||
name = name if name is not None else str(id)
|
||||
score_dict[name] = score
|
||||
@@ -451,6 +494,7 @@ def plot_roc_and_calculate_tpr(scores, names=None, label=None):
|
||||
tpr_result[name] = [tpr[np.argmin(abs(fpr - ii))] for ii in x_labels]
|
||||
fpr_dict[name], tpr_dict[name], roc_auc_dict[name] = fpr, tpr, roc_auc
|
||||
tpr_result_df = pd.DataFrame(tpr_result, index=x_labels).T
|
||||
tpr_result_df['AUC'] = pd.Series(roc_auc_dict)
|
||||
tpr_result_df.columns.name = "Methods"
|
||||
print(tpr_result_df.to_markdown())
|
||||
# print(tpr_result_df)
|
||||
@@ -461,35 +505,71 @@ def plot_roc_and_calculate_tpr(scores, names=None, label=None):
|
||||
fig = plt.figure()
|
||||
for name in score_dict:
|
||||
plt.plot(fpr_dict[name], tpr_dict[name], lw=1, label="[%s (AUC = %0.4f%%)]" % (name, roc_auc_dict[name] * 100))
|
||||
title = "ROC on IJB" + name.split("IJB")[-1][0] if "IJB" in name else "ROC on IJB"
|
||||
|
||||
plt.xlim([10 ** -6, 0.1])
|
||||
plt.ylim([0.3, 1.0])
|
||||
plt.grid(linestyle="--", linewidth=1)
|
||||
plt.xticks(x_labels)
|
||||
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
|
||||
plt.xscale("log")
|
||||
plt.xticks(x_labels)
|
||||
plt.xlabel("False Positive Rate")
|
||||
plt.ylim([0.3, 1.0])
|
||||
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
|
||||
plt.ylabel("True Positive Rate")
|
||||
plt.title("ROC on IJB")
|
||||
plt.legend(loc="lower right")
|
||||
|
||||
plt.grid(linestyle="--", linewidth=1)
|
||||
plt.title(title)
|
||||
plt.legend(loc="lower right", fontsize='x-small')
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
except:
|
||||
print("Missing matplotlib")
|
||||
print("matplotlib plot failed")
|
||||
fig = None
|
||||
|
||||
return tpr_result_df, fig
|
||||
|
||||
|
||||
def plot_dir_far_cmc_scores(scores, names=None):
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig = plt.figure()
|
||||
for id, score in enumerate(scores):
|
||||
name = None if names is None else names[id]
|
||||
if isinstance(score, str) and score.endswith(".npz"):
|
||||
aa = np.load(score)
|
||||
score, name = aa.get("scores")[0], aa.get("names")[0]
|
||||
fars, tpirs = score[0], score[1]
|
||||
name = name if name is not None else str(id)
|
||||
|
||||
auc_value = auc(fars, tpirs)
|
||||
label = "[%s (AUC = %0.4f%%)]" % (name, auc_value * 100)
|
||||
plt.plot(fars, tpirs, lw=1, label=label)
|
||||
|
||||
plt.xlabel("False Alarm Rate")
|
||||
plt.xlim([0.0001, 1])
|
||||
plt.xscale("log")
|
||||
plt.ylabel("Detection & Identification Rate (%)")
|
||||
plt.ylim([0, 1])
|
||||
|
||||
plt.grid(linestyle="--", linewidth=1)
|
||||
plt.legend(fontsize='x-small')
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
except:
|
||||
print("matplotlib plot failed")
|
||||
fig = None
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def parse_arguments(argv):
|
||||
import argparse
|
||||
|
||||
default_save_result_name = "IJB_result/{model_name}_{subset}.npz"
|
||||
default_save_result_name = "IJB_result/{model_name}_{subset}_{type}.npz"
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("-m", "--model_file", type=str, default=None, help="Saved model file, could be keras h5 / pytorch jit pth / mxnet")
|
||||
parser.add_argument("-d", "--data_path", type=str, default="./", help="Dataset path")
|
||||
parser.add_argument("-s", "--subset", type=str, default="IJBB", help="Subset test target, could be IJBB / IJBC")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=64, help="Batch size for get_embeddings")
|
||||
parser.add_argument("-m", "--model_file", type=str, default=None, help="Saved model, keras h5 / pytorch jit pth / onnx / mxnet")
|
||||
parser.add_argument("-d", "--data_path", type=str, default="./", help="Dataset path containing IJBB and IJBC sub folder")
|
||||
parser.add_argument("-s", "--subset", type=str, default="IJBC", help="Subset test target, could be IJBB / IJBC")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=128, help="Batch size for get_embeddings")
|
||||
parser.add_argument(
|
||||
"-R", "--save_result", type=str, default=default_save_result_name, help="Filename for saving / restore result"
|
||||
)
|
||||
@@ -513,15 +593,16 @@ def parse_arguments(argv):
|
||||
print("Please provide -m MODEL_FILE, see `--help` for usage.")
|
||||
exit(1)
|
||||
elif args.model_file != None:
|
||||
if args.model_file.endswith(".h5") or args.model_file.endswith(".pth") or args.model_file.endswith(".pt"):
|
||||
# Keras model file "model.h5", pytorch model ends with `.pth` or `.pt`
|
||||
if args.model_file.endswith(".h5") or args.model_file.endswith(".pth") or args.model_file.endswith(".pt") or args.model_file.endswith(".onnx"):
|
||||
# Keras model file "model.h5", pytorch model ends with `.pth` or `.pt`, onnx model ends with `.onnx`
|
||||
model_name = os.path.splitext(os.path.basename(args.model_file))[0]
|
||||
else:
|
||||
# MXNet model file "models/r50-arcface-emore/model,1"
|
||||
model_name = os.path.basename(os.path.dirname(args.model_file))
|
||||
|
||||
if args.save_result == default_save_result_name:
|
||||
args.save_result = default_save_result_name.format(model_name=model_name, subset=args.subset)
|
||||
type = "1N" if args.is_one_2_N else "11"
|
||||
args.save_result = default_save_result_name.format(model_name=model_name, subset=args.subset, type=type)
|
||||
return args
|
||||
|
||||
|
||||
@@ -530,32 +611,45 @@ if __name__ == "__main__":
|
||||
|
||||
args = parse_arguments(sys.argv[1:])
|
||||
if args.plot_only != None and len(args.plot_only) != 0:
|
||||
plot_roc_and_calculate_tpr(args.plot_only)
|
||||
if args.is_one_2_N:
|
||||
plot_dir_far_cmc_scores(args.plot_only)
|
||||
else:
|
||||
plot_roc_and_calculate_tpr(args.plot_only)
|
||||
else:
|
||||
save_name = os.path.splitext(args.save_result)[0]
|
||||
save_name = os.path.splitext(os.path.basename(args.save_result))[0]
|
||||
save_items = {}
|
||||
save_path = os.path.dirname(args.save_result)
|
||||
if len(save_path) != 0 and not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
|
||||
tt = IJB_test(args.model_file, args.data_path, args.subset, args.batch_size, args.force_reload, args.save_result)
|
||||
if args.save_embeddings: # Save embeddings first, in case of any error happens later...
|
||||
np.savez(args.save_result, embs=tt.embs, embs_f=tt.embs_f)
|
||||
|
||||
if args.is_one_2_N: # 1:N test
|
||||
tt.run_model_test_1N()
|
||||
fars, tpirs, _, _ = tt.run_model_test_1N()
|
||||
scores = [(fars, tpirs)]
|
||||
names = [save_name]
|
||||
save_items.update({"scores": scores, "names": names})
|
||||
elif args.is_bunch: # All 8 tests N{0,1}D{0,1}F{0,1}
|
||||
scores, names = tt.run_model_test_bunch()
|
||||
names = [save_name + "_" + ii for ii in names]
|
||||
label = tt.label
|
||||
save_items.update({"scores": scores, "names": names})
|
||||
else: # Basic 1:1 N0D1F1 test
|
||||
score = tt.run_model_test_single()
|
||||
scores, names = [score], [save_name]
|
||||
scores, names, label = [score], [save_name], tt.label
|
||||
save_items.update({"scores": scores, "names": names})
|
||||
|
||||
if args.save_embeddings:
|
||||
save_items.update({"embs": tt.embs, "embs_f": tt.embs_f})
|
||||
if args.save_label:
|
||||
save_items.update({"label": tt.label})
|
||||
save_items.update({"label": label})
|
||||
|
||||
if args.model_file != None or args.save_embeddings: # embeddings not restored from file or should save_embeddings again
|
||||
save_path = os.path.dirname(args.save_result)
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
np.savez(args.save_result, **save_items)
|
||||
|
||||
if not args.is_one_2_N:
|
||||
plot_roc_and_calculate_tpr(scores, names=names, label=tt.label)
|
||||
if args.is_one_2_N:
|
||||
plot_dir_far_cmc_scores(scores=scores, names=names)
|
||||
else:
|
||||
plot_roc_and_calculate_tpr(scores, names=names, label=label)
|
||||
|
||||
Reference in New Issue
Block a user