mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-21 17:17:49 +00:00
90 lines
3.5 KiB
Python
Executable File
90 lines
3.5 KiB
Python
Executable File
from collections import OrderedDict
|
|
|
|
from mmdet.core import eval_map, eval_recalls
|
|
from .builder import DATASETS
|
|
from .xml_style import XMLDataset
|
|
|
|
|
|
@DATASETS.register_module()
class VOCDataset(XMLDataset):
    """XML-style dataset using the 20 VOC object categories.

    The dataset year (2007 or 2012) is inferred from ``img_prefix`` when the
    dataset is constructed. :meth:`evaluate` later uses it to choose the
    ``dataset`` argument handed to ``eval_map``.
    """

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, **kwargs):
        """Build the dataset and infer ``self.year`` from ``self.img_prefix``.

        Args:
            **kwargs: Forwarded unchanged to the parent ``__init__``.

        Raises:
            ValueError: If ``img_prefix`` contains neither ``'VOC2007'`` nor
                ``'VOC2012'``.
        """
        super(VOCDataset, self).__init__(**kwargs)
        # 'VOC2007' is checked first, so it wins if both markers are present.
        if 'VOC2007' in self.img_prefix:
            self.year = 2007
        elif 'VOC2012' in self.img_prefix:
            self.year = 2012
        else:
            raise ValueError('Cannot infer dataset year from img_prefix')

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None):
        """Evaluate in VOC protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'mAP', 'recall'. A non-string value must contain exactly one
                entry.
            logger (logging.Logger | str, optional): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thr (float | list[float]): IoU threshold. It must be a float
                when evaluating mAP, and can be a list when evaluating recall.
                Default: 0.5.
            scale_ranges (list[tuple], optional): Scale ranges for evaluating
                mAP. Forwarded to ``eval_map``. If not specified, all bounding
                boxes would be included in evaluation. Default: None.

        Returns:
            dict[str, float]: AP/recall metrics. For 'mAP' the single key
            ``'mAP'`` holds the first value returned by ``eval_map``.
            NOTE(review): with ``scale_ranges`` given, ``eval_map`` presumably
            returns one value per range rather than a single float; confirm
            against its documentation. For 'recall' the keys are
            ``'recall@{num}@{iou}'`` plus ``'AR@{num}'`` when more than one
            IoU threshold is evaluated.

        Raises:
            KeyError: If ``metric`` is not one of the supported options.
            AssertionError: If a non-string ``metric`` does not hold exactly
                one entry, or ``iou_thr`` is not a float for 'mAP'.
        """
        if not isinstance(metric, str):
            assert len(metric) == 1, \
                'exactly one metric must be given per evaluate() call'
            metric = metric[0]
        allowed_metrics = ['mAP', 'recall']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')

        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = OrderedDict()

        if metric == 'mAP':
            assert isinstance(iou_thr, float), \
                'iou_thr must be a float when evaluating mAP'
            # 2007 gets the 'voc07' dataset tag; any other year passes the
            # class-name tuple instead. NOTE(review): presumably the tag
            # selects a VOC2007-specific AP computation inside eval_map;
            # confirm against eval_map.
            if self.year == 2007:
                ds_name = 'voc07'
            else:
                ds_name = self.CLASSES
            # Forward the caller's scale_ranges. This was hard-coded to None,
            # which silently discarded the documented argument.
            mean_ap, _ = eval_map(
                results,
                annotations,
                scale_ranges=scale_ranges,
                iou_thr=iou_thr,
                dataset=ds_name,
                logger=logger)
            eval_results['mAP'] = mean_ap
        elif metric == 'recall':
            gt_bboxes = [ann['bboxes'] for ann in annotations]
            # Normalise a single threshold to a list so the loops below can
            # index the second axis of `recalls`.
            if isinstance(iou_thr, float):
                iou_thr = [iou_thr]
            recalls = eval_recalls(
                gt_bboxes, results, proposal_nums, iou_thr, logger=logger)
            for i, num in enumerate(proposal_nums):
                for j, iou in enumerate(iou_thr):
                    eval_results[f'recall@{num}@{iou}'] = recalls[i, j]
            # Average recall is only reported when several IoU thresholds
            # were evaluated.
            if recalls.shape[1] > 1:
                ar = recalls.mean(axis=1)
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
        return eval_results
|