mirror of
https://github.com/MarcosRodrigoT/ViT-Face-Recognition.git
synced 2025-12-30 08:02:29 +00:00
Add compute_roc() & plot_and_csv() functions
This commit is contained in:
111
scface_test.py
111
scface_test.py
@@ -1,8 +1,12 @@
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
import matplotlib.pyplot as plt
|
||||
from vit_keras import vit
|
||||
from scipy.spatial.distance import cosine
|
||||
from sklearn.metrics import roc_curve, auc
|
||||
|
||||
|
||||
def preprocess_image(img_path):
|
||||
@@ -20,6 +24,89 @@ def compute_score(embeddings1, embeddings2):
|
||||
return score
|
||||
|
||||
|
||||
def plot_and_csv(models, ground_truth_, cameras, distances, positive_label=1):
|
||||
results_dir = os.path.join('./saved_results/Tests/SCface', f"ROC-CAMERAS-{'_'.join(cameras)}-DISTANCES-{'_'.join(distances)}")
|
||||
try:
|
||||
os.mkdir(results_dir)
|
||||
except FileExistsError:
|
||||
pass
|
||||
|
||||
# Figure
|
||||
fig, ax = plt.subplots(1, 1, figsize=(10, 9))
|
||||
|
||||
for model in models.keys():
|
||||
model_name_ = models[model]['name']
|
||||
model_color_ = models[model]['color']
|
||||
model_scores_ = models[model]['scores']
|
||||
|
||||
# Data
|
||||
fpr, tpr, thresholds = roc_curve(ground_truth_, model_scores_, pos_label=positive_label)
|
||||
auc_result = auc(fpr, tpr)
|
||||
fnr = 1 - tpr
|
||||
eer = fpr[np.argmin(np.absolute(fnr - fpr))]
|
||||
eer_threshold = thresholds[np.argmin(np.absolute(fnr - fpr))]
|
||||
|
||||
# Plot
|
||||
ax.plot(fpr, tpr, linestyle='-', lw=3, color=model_color_, label=f'{model_name_} (EER={eer:.2f}, AUC={auc_result:.3f})')
|
||||
ax.scatter(eer, tpr[np.argmin(np.absolute(fnr - fpr))], color=model_color_, linewidths=8, zorder=10)
|
||||
|
||||
# CSV
|
||||
result_pd = pd.DataFrame({'FPR': fpr, 'TPR': tpr})
|
||||
result_pd['EER'] = pd.DataFrame([eer, tpr[np.argmin(np.absolute(fnr - fpr))]])
|
||||
result_pd.to_csv(f"{results_dir}/{model_name_}_ROC.csv", header=True, index=False)
|
||||
|
||||
ax.set_title('Receiver Operating Characteristics (ROC)', fontsize=15)
|
||||
ax.set_xlabel('FPR', fontsize=15)
|
||||
ax.set_ylabel('TPR', fontsize=15)
|
||||
plt.xticks(fontsize=15)
|
||||
plt.yticks(fontsize=15)
|
||||
ax.legend(loc='lower right', prop={"size": 11})
|
||||
|
||||
ax.set_xlim([0.0, 1.0])
|
||||
ax.set_ylim([0.0, 1.05])
|
||||
plt.savefig(f"{results_dir}/ROC.png", bbox_inches='tight')
|
||||
|
||||
ax.set_xlim([0.0, 0.3])
|
||||
ax.set_ylim([0.7, 1.0])
|
||||
plt.savefig(f"{results_dir}/ROC_zoom.png", bbox_inches='tight')
|
||||
|
||||
|
||||
def compute_roc(scores_dict, cameras, distances):
|
||||
vit_scores = []
|
||||
resnet_scores = []
|
||||
vgg_scores = []
|
||||
inception_scores = []
|
||||
mobilenet_scores = []
|
||||
efficientnet_scores = []
|
||||
ground_truth = []
|
||||
|
||||
for mug_person_ in scores_dict.keys():
|
||||
for sur_item_, sur_values_ in scores_dict[mug_person_].items():
|
||||
person_ = sur_values_['person']
|
||||
cam_ = sur_values_['camera']
|
||||
dist_ = sur_values_['distance']
|
||||
|
||||
if cam_ in cameras and dist_ in distances:
|
||||
vit_scores.append(scores_dict[mug_person_][sur_item_]['vit'])
|
||||
resnet_scores.append(scores_dict[mug_person_][sur_item_]['resnet'])
|
||||
vgg_scores.append(scores_dict[mug_person_][sur_item_]['vgg'])
|
||||
inception_scores.append(scores_dict[mug_person_][sur_item_]['inception'])
|
||||
mobilenet_scores.append(scores_dict[mug_person_][sur_item_]['mobilenet'])
|
||||
efficientnet_scores.append(scores_dict[mug_person_][sur_item_]['efficientnet'])
|
||||
|
||||
ground_truth.append(1 if person_ == mug_person_ else 0) # 1 if same person, 0 if different
|
||||
|
||||
models_scores = {
|
||||
'vit': {'name': 'ViT_B32', 'color': 'blue', 'scores': vit_scores},
|
||||
'resnet': {'name': 'ResNet_50', 'color': 'orange', 'scores': resnet_scores},
|
||||
'vgg': {'name': 'VGG_16', 'color': 'green', 'scores': vgg_scores},
|
||||
'inception': {'name': 'Inception_V3', 'color': 'cyan', 'scores': inception_scores},
|
||||
'mobilenet': {'name': 'MobileNet_V2', 'color': 'magenta', 'scores': mobilenet_scores},
|
||||
'efficientnet': {'name': 'EfficientNet_B0', 'color': 'brown', 'scores': efficientnet_scores},
|
||||
}
|
||||
plot_and_csv(models_scores, ground_truth, cameras, distances)
|
||||
|
||||
|
||||
"""
|
||||
CREATE DATASET
|
||||
"""
|
||||
@@ -261,3 +348,27 @@ except FileNotFoundError:
|
||||
with open('./saved_results/Tests/SCface/scores.pickle', 'wb') as scores_file:
|
||||
pickle.dump(scores, scores_file)
|
||||
|
||||
"""
|
||||
COMPUTE ROC CURVES
|
||||
"""
|
||||
|
||||
|
||||
compute_roc(
|
||||
scores,
|
||||
cameras=[
|
||||
'cam1',
|
||||
'cam2',
|
||||
'cam3',
|
||||
'cam4',
|
||||
'cam5',
|
||||
# 'cam6',
|
||||
# 'cam7',
|
||||
# 'cam8',
|
||||
],
|
||||
distances=[
|
||||
'1',
|
||||
'2',
|
||||
'3',
|
||||
# 'None',
|
||||
]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user