From d0446827e94fbda7a3260d2a241892461cf9f0ee Mon Sep 17 00:00:00 2001 From: yakhyo Date: Thu, 10 Apr 2025 18:00:39 +0900 Subject: [PATCH] feat: Add new models --- requirements.txt | 1 + scripts/run_inference.py | 4 +- scripts/sha256_generate.py | 35 ++++++++ uniface/common.py | 2 +- uniface/constants.py | 58 +++++++++++- uniface/detection/retinaface.py | 8 +- uniface/detection/scrfd.py | 153 ++++++++++++++++++++------------ uniface/log.py | 3 +- uniface/model_store.py | 41 +++++---- uniface/retinaface.py | 8 +- 10 files changed, 225 insertions(+), 88 deletions(-) create mode 100644 scripts/sha256_generate.py diff --git a/requirements.txt b/requirements.txt index 42d8005..8df2fff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ onnxruntime-gpu scikit-image requests pytest +tqdm \ No newline at end of file diff --git a/scripts/run_inference.py b/scripts/run_inference.py index f95376d..8904c66 100644 --- a/scripts/run_inference.py +++ b/scripts/run_inference.py @@ -4,8 +4,8 @@ import time import argparse import numpy as np -from uniface.detection import RetinaFace, draw_detections -from uniface.constants import RetinaFaceWeights +from uniface.detection import RetinaFace, draw_detections, SCRFD +from uniface.constants import RetinaFaceWeights, SCRFDWeights def run_inference(model, image_path, vis_threshold=0.6, save_dir="outputs"): diff --git a/scripts/sha256_generate.py b/scripts/sha256_generate.py new file mode 100644 index 0000000..7213f67 --- /dev/null +++ b/scripts/sha256_generate.py @@ -0,0 +1,35 @@ +import argparse +import hashlib +from pathlib import Path + + +def compute_sha256(file_path: Path, chunk_size: int = 8192) -> str: + sha256_hash = hashlib.sha256() + with file_path.open("rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): + sha256_hash.update(chunk) + return sha256_hash.hexdigest() + + +def main(): + parser = argparse.ArgumentParser( + description="Compute SHA256 hash of a model weight file." + ) + parser.add_argument( + "file", + type=Path, + help="Path to the model weight file (.onnx, .pth, etc)." + ) + + args = parser.parse_args() + + if not args.file.exists() or not args.file.is_file(): + print(f"File does not exist: {args.file}") + return + + sha256 = compute_sha256(args.file) + print(f"`SHA256 hash for '{args.file.name}':\n{sha256}") + + +if __name__ == "__main__": + main() diff --git a/uniface/common.py b/uniface/common.py index 8fbcc9f..9133203 100644 --- a/uniface/common.py +++ b/uniface/common.py @@ -82,7 +82,7 @@ def generate_anchors(image_size: Tuple[int, int] = (640, 640)) -> np.ndarray: return output -def nms(dets: List[np.ndarray], threshold: float): +def non_max_supression(dets: List[np.ndarray], threshold: float): """ Apply Non-Maximum Suppression (NMS) to reduce overlapping bounding boxes based on a threshold. diff --git a/uniface/constants.py b/uniface/constants.py index e6b174a..e0763f7 100644 --- a/uniface/constants.py +++ b/uniface/constants.py @@ -44,25 +44,75 @@ class RetinaFaceWeights(str, Enum): RESNET18 = "retinaface_r18" RESNET34 = "retinaface_r34" + +class SCRFDWeights(str, Enum): + """ + Trained on WIDER FACE dataset. + https://github.com/deepinsight/insightface/tree/master/detection/scrfd + """ + SCRFD_10G_KPS = "scrfd_10g" + SCRFD_500M_KPS = "scrfd_500m" + # fmt: on -MODEL_URLS: Dict[RetinaFaceWeights, str] = { +MODEL_URLS: Dict[Enum, str] = { + + # RetinaFace RetinaFaceWeights.MNET_025: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv1_0.25.onnx', RetinaFaceWeights.MNET_050: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv1_0.50.onnx', RetinaFaceWeights.MNET_V1: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv1.onnx', RetinaFaceWeights.MNET_V2: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv2.onnx', RetinaFaceWeights.RESNET18: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_r18.onnx', - RetinaFaceWeights.RESNET34: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_r34.onnx' + RetinaFaceWeights.RESNET34: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_r34.onnx', + + # MobileFace + MobileFaceWeights.MNET_025: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/###', + MobileFaceWeights.MNET_V2: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/###', + MobileFaceWeights.MNET_V3_SMALL: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/###', + MobileFaceWeights.MNET_V3_LARGE: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/###', + + # SphereFace + SphereFaceWeights.SPHERE20: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/###', + SphereFaceWeights.SPHERE36: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/###', + + + # ArcFace + ArcFaceWeights.MNET: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/w600k_mbf.onnx', + ArcFaceWeights.RESNET: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/w600k_r50.onnx', + + # SCRFD + SCRFDWeights.SCRFD_10G_KPS: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/scrfd_10g_kps.onnx', + SCRFDWeights.SCRFD_500M_KPS: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/scrfd_500m_kps.onnx', } -MODEL_SHA256: Dict[RetinaFaceWeights, str] = { +MODEL_SHA256: Dict[Enum, str] = { + # RetinaFace RetinaFaceWeights.MNET_025: 'b7a7acab55e104dce6f32cdfff929bd83946da5cd869b9e2e9bdffafd1b7e4a5', RetinaFaceWeights.MNET_050: 'd8977186f6037999af5b4113d42ba77a84a6ab0c996b17c713cc3d53b88bfc37', RetinaFaceWeights.MNET_V1: '75c961aaf0aff03d13c074e9ec656e5510e174454dd4964a161aab4fe5f04153', RetinaFaceWeights.MNET_V2: '3ca44c045651cabeed1193a1fae8946ad1f3a55da8fa74b341feab5a8319f757', RetinaFaceWeights.RESNET18: 'e8b5ddd7d2c3c8f7c942f9f10cec09d8e319f78f09725d3f709631de34fb649d', - RetinaFaceWeights.RESNET34: 'bd0263dc2a465d32859555cb1741f2d98991eb0053696e8ee33fec583d30e630' + RetinaFaceWeights.RESNET34: 'bd0263dc2a465d32859555cb1741f2d98991eb0053696e8ee33fec583d30e630', + + # MobileFace + MobileFaceWeights.MNET_025: '#', + MobileFaceWeights.MNET_V2: '#', + MobileFaceWeights.MNET_V3_SMALL: '#', + MobileFaceWeights.MNET_V3_LARGE: '#', + + # SphereFace + SphereFaceWeights.SPHERE20: '#', + SphereFaceWeights.SPHERE36: '#', + + + # ArcFace + ArcFaceWeights.MNET: '9cc6e4a75f0e2bf0b1aed94578f144d15175f357bdc05e815e5c4a02b319eb4f', + ArcFaceWeights.RESNET: '4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43', + + # SCRFD + SCRFDWeights.SCRFD_10G_KPS: '5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91', + SCRFDWeights.SCRFD_500M_KPS: '5e4447f50245bbd7966bd6c0fa52938c61474a04ec7def48753668a9d8b4ea3a', } CHUNK_SIZE = 8192 diff --git a/uniface/detection/retinaface.py b/uniface/detection/retinaface.py index 5d2cf37..f786847 100644 --- a/uniface/detection/retinaface.py +++ b/uniface/detection/retinaface.py @@ -161,7 +161,7 @@ class RetinaFace: """ original_height, original_width = image.shape[:2] - + if self.dynamic_size: height, width, _ = image.shape self._priors = generate_anchors(image_size=(height, width)) # generate anchors for each input image @@ -244,7 +244,7 @@ class RetinaFace: # Apply NMS detections = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) - keep = nms(detections, self.nms_thresh) + keep = non_max_supression(detections, self.nms_thresh) detections, landmarks = detections[keep], landmarks[keep] # Keep top-k detections @@ -255,11 +255,11 @@ class RetinaFace: return detections, landmarks def _scale_detections(self, boxes: np.ndarray, landmarks: np.ndarray, resize_factor: float, shape: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray]: - """Scale bounding boxes and landmarks to the original image size.""" + # Scale bounding boxes and landmarks to the original image size. bbox_scale = np.array([shape[0], shape[1]] * 2) boxes = boxes * bbox_scale / resize_factor landmark_scale = np.array([shape[0], shape[1]] * 5) landmarks = landmarks * landmark_scale / resize_factor - return boxes, landmarks \ No newline at end of file + return boxes, landmarks diff --git a/uniface/detection/scrfd.py b/uniface/detection/scrfd.py index 99b0f4b..c874d67 100644 --- a/uniface/detection/scrfd.py +++ b/uniface/detection/scrfd.py @@ -1,11 +1,18 @@ +# Copyright 2025 Yakhyokhuja Valikhujaev +# Author: Yakhyokhuja Valikhujaev +# GitHub: https://github.com/yakhyo +# Modified from insightface repo + import os import cv2 import numpy as np -import onnxruntime +import onnxruntime as ort from typing import Tuple, Optional, List, Literal -# from uniface.logger import Logger +from uniface.log import Logger +from uniface.model_store import verify_model_weights +from uniface.constants import SCRFDWeights from .utils import non_max_supression, distance2bbox, distance2kps, resize_image __all__ = ['SCRFD'] @@ -13,58 +20,85 @@ __all__ = ['SCRFD'] class SCRFD: """ + A class for face detection using the SCRFD model. + Title: "Sample and Computation Redistribution for Efficient Face Detection" Paper: https://arxiv.org/abs/2105.04714 + + Args: + model_name (SCRFDWeights): Predefined model enum (e.g., SCRFD_10G_KPS). Specifies which SCRFD variant to load. + conf_thresh (float): Confidence threshold for filtering detections. Defaults to 0.5. + nms_thresh (float): Non-Maximum Suppression threshold. Defaults to 0.4. + input_size (Optional[Tuple[int, int]]): Input resolution (height, width) to which the image is resized. Defaults to (640, 640). + + Attributes: + conf_thresh (float): Confidence threshold used to filter raw detections. + nms_thresh (float): NMS threshold to suppress overlapping detections. + input_size (Tuple[int, int]): Target resolution for input resizing. + _fmc (int): Number of feature map scales used in SCRFD. + _feat_stride_fpn (List[int]): Stride values for each feature map. + _num_anchors (int): Number of anchors per feature point. + _center_cache (Dict): Cache of anchor centers for efficient inference. + _model_path (str): Verified path to the downloaded model weights. + + Raises: + ValueError: If the model weights cannot be found or verified. + RuntimeError: If the ONNX model fails to load or initialize. """ def __init__( self, - model_path: str, - input_size: Tuple[int] = (640, 640), - conf_thres: float = 0.5, - iou_thres: float = 0.4 + model_name: SCRFDWeights = SCRFDWeights.SCRFD_10G_KPS, + conf_thresh: float = 0.5, + nms_thresh: float = 0.4, + input_size: Optional[Tuple[int, int]] = (640, 640), ) -> None: - """SCRFD initialization - - Args: - model_path (str): Path model .onnx file. - input_size (int): Input image size. Defaults to (640, 640) - max_num (int): Maximum number of detections - conf_thres (float, optional): Confidence threshold. Defaults to 0.5. - iou_thres (float, optional): Non-max supression (NMS) threshold. Defaults to 0.4. - """ + self.conf_thresh = conf_thresh + self.nms_thresh = nms_thresh self.input_size = input_size - self.conf_thres = conf_thres - self.iou_thres = iou_thres # SCRFD model params -------------- - self.fmc = 3 + self._fmc = 3 self._feat_stride_fpn = [8, 16, 32] self._num_anchors = 2 - self.center_cache = {} + self._center_cache = {} # --------------------------------- - self._initialize_model(model_path=model_path) + Logger.info( + f"Initializing SCRFD with model={model_name}, conf_thresh={conf_thresh}, nms_thresh={nms_thresh}, " + f"input_size={input_size}" + ) - def _initialize_model(self, model_path: str): - """Initialize the model from the given path. + # Get path to model weights + self._model_path = verify_model_weights(model_name) + Logger.info(f"Verified model weights located at: {self._model_path}") + + # Initialize model + self._initialize_model(self._model_path) + + def _initialize_model(self, model_path: str) -> None: + """ + Initializes an ONNX model session from the given path. Args: - model_path (str): Path to .onnx model. + model_path (str): The file path to the ONNX model. + + Raises: + RuntimeError: If the model fails to load, logs an error and raises an exception. """ try: - self.session = onnxruntime.InferenceSession( + self.session = ort.InferenceSession( model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"] ) - # Get model info self.input_names = self.session.get_inputs()[0].name self.output_names = [x.name for x in self.session.get_outputs()] + Logger.info(f"Successfully initialized the model from {model_path}") except Exception as e: - print(f"Failed to load the model: {e}") - raise + Logger.error(f"Failed to load model from '{model_path}': {e}", exc_info=True) + raise RuntimeError(f"Failed to initialize model session for '{model_path}'") from e def preprocess(self, image: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]: """Preprocess image for inference. @@ -75,14 +109,12 @@ class SCRFD: Returns: Tuple[np.ndarray, Tuple[int, int]]: Preprocessed blob and input size """ - input_size = tuple(image.shape[0:2][::-1]) - image = image.astype(np.float32) image = (image - 127.5) / 127.5 image = image.transpose(2, 0, 1) # HWC to CHW image = np.expand_dims(image, axis=0) - return image, input_size + return image def inference(self, input_tensor: np.ndarray) -> List[np.ndarray]: """Perform model inference on the preprocessed image tensor. @@ -95,26 +127,26 @@ class SCRFD: """ return self.session.run(self.output_names, {self.input_names: input_tensor}) - def postprocess(self, outputs: List[np.ndarray], image_dim: Tuple[int, int]): + def postprocess(self, outputs: List[np.ndarray], image_size: Tuple[int, int]): scores_list = [] bboxes_list = [] kpss_list = [] - input_height, input_width = image_dim + image_size = image_size - fmc = self.fmc + fmc = self._fmc for idx, stride in enumerate(self._feat_stride_fpn): scores = outputs[idx] bbox_preds = outputs[fmc + idx] * stride kps_preds = outputs[2*fmc + idx] * stride # Generate anchors - fm_height = input_height // stride - fm_width = input_width // stride + fm_height = image_size[0] // stride + fm_width = image_size[1] // stride cache_key = (fm_height, fm_width, stride) - if cache_key in self.center_cache: - anchor_centers = self.center_cache[cache_key] + if cache_key in self._center_cache: + anchor_centers = self._center_cache[cache_key] else: y, x = np.mgrid[:fm_height, :fm_width] anchor_centers = np.stack((x, y), axis=-1).astype(np.float32) @@ -123,10 +155,10 @@ class SCRFD: if self._num_anchors > 1: anchor_centers = np.tile(anchor_centers[:, None, :], (1, self._num_anchors, 1)).reshape(-1, 2) - if len(self.center_cache) < 100: - self.center_cache[cache_key] = anchor_centers + if len(self._center_cache) < 100: + self._center_cache[cache_key] = anchor_centers - pos_indices = np.where(scores >= self.conf_thres)[0] + pos_indices = np.where(scores >= self.conf_thresh)[0] if len(pos_indices) == 0: continue @@ -135,9 +167,9 @@ class SCRFD: scores_list.append(scores_selected) bboxes_list.append(bboxes) - kpss = distance2kps(anchor_centers, kps_preds) - kpss = kpss.reshape((kpss.shape[0], -1, 2)) - kpss_list.append(kpss[pos_indices]) + landmarks = distance2kps(anchor_centers, kps_preds) + landmarks = landmarks.reshape((landmarks.shape[0], -1, 2)) + kpss_list.append(landmarks[pos_indices]) return scores_list, bboxes_list, kpss_list @@ -153,48 +185,57 @@ class SCRFD: image, resize_factor = resize_image(image, target_shape=self.input_size) - image_tensor, _ = self.preprocess(image) + image_tensor = self.preprocess(image) + # ONNXRuntime inference outputs = self.inference(image_tensor) - scores_list, bboxes_list, kpss_list = self.postprocess(outputs, image.shape[:2]) + scores_list, bboxes_list, kpss_list = self.postprocess(outputs, image_size=image.shape[:2]) scores = np.vstack(scores_list) scores_ravel = scores.ravel() order = scores_ravel.argsort()[::-1] - bboxes = np.vstack(bboxes_list) / resize_factor - kpss = np.vstack(kpss_list) / resize_factor + bboxes = np.vstack(bboxes_list) / resize_factor + landmarks = np.vstack(kpss_list) / resize_factor pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) pre_det = pre_det[order, :] - keep = non_max_supression(pre_det, threshold=self.iou_thres) - det = pre_det[keep, :] - kpss = kpss[order, :, :] - kpss = kpss[keep, :, :] + keep = non_max_supression(pre_det, threshold=self.nms_thresh) + + det = pre_det[keep, :] + landmarks = landmarks[order, :, :] + landmarks = landmarks[keep, :, :].astype(np.int32) if 0 < max_num < det.shape[0]: + # Calculate area of detections area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1]) - center = (original_height // 2, original_width // 2) + # Calculate offsets from image center + center = (original_height // 2, original_width // 2) offsets = np.vstack( [ (det[:, 0] + det[:, 2]) / 2 - center[1], (det[:, 1] + det[:, 3]) / 2 - center[0], ] ) + + # Calculate scores based on the chosen metric offset_dist_squared = np.sum(np.power(offsets, 2.0), axis=0) if metric == "max": values = area else: - values = area - offset_dist_squared * center_weight # some extra weight on the centering + values = area - offset_dist_squared * center_weight + # Sort by scores and select top `max_num` sorted_indices = np.argsort(values)[::-1][:max_num] det = det[sorted_indices] - kpss = kpss[sorted_indices] + landmarks = landmarks[sorted_indices] + + - return det, kpss + return det, landmarks def draw_bbox(frame, bbox, color=(0, 255, 0), thickness=2): @@ -210,7 +251,7 @@ def draw_keypoints(frame, points, color=(0, 0, 255), radius=2): if __name__ == "__main__": - detector = SCRFD(model_path="det_10g.onnx") + detector = SCRFD() cap = cv2.VideoCapture(0) if not cap.isOpened(): diff --git a/uniface/log.py b/uniface/log.py index 1d88bd5..ab3d4fc 100644 --- a/uniface/log.py +++ b/uniface/log.py @@ -2,6 +2,7 @@ import logging logging.basicConfig( level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s" + format="%(asctime)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S" ) Logger = logging.getLogger("uniface") diff --git a/uniface/model_store.py b/uniface/model_store.py index b7aac40..2ab95e9 100644 --- a/uniface/model_store.py +++ b/uniface/model_store.py @@ -5,6 +5,7 @@ import os import hashlib import requests +from tqdm import tqdm from uniface.log import Logger import uniface.constants as const @@ -15,36 +16,38 @@ __all__ = ['verify_model_weights'] def verify_model_weights(model_name: str, root: str = '~/.uniface/models') -> str: """ - Ensures model weights are available by downloading if missing and verifying integrity with a SHA-256 hash. + Ensure model weights are present, downloading and verifying them using SHA-256 if necessary. - Checks if the specified model weights file exists in `root`. If missing, downloads from a predefined URL. - The file is then verified using its SHA-256 hash. If verification fails, the corrupted file is deleted, + Given a model identifier from an Enum class (e.g., `RetinaFaceWeights.MNET_V2`), this function checks if + the corresponding `.onnx` weight file exists locally. If not, it downloads the file from a predefined URL. + After download, the file’s integrity is verified using a SHA-256 hash. If verification fails, the file is deleted and an error is raised. Args: - model_name (str): Name of the model weights to verify or download. - root (str, optional): Directory to store the model weights. Defaults to '~/.uniface/models'. + model_name (Enum): Model weight identifier (e.g., `RetinaFaceWeights.MNET_V2`, `ArcFaceWeights.RESNET`, etc.). + root (str, optional): Directory to store or locate the model weights. Defaults to '~/.uniface/models'. Returns: - str: Path to the verified model weights file. + str: Absolute path to the verified model weights file. Raises: - ValueError: If the model is not found or if verification fails. + ValueError: If the model is unknown or SHA-256 verification fails. ConnectionError: If downloading the file fails. Examples: - >>> # Download and verify 'retinaface_mnet025' weights - >>> verify_model_weights('retinaface_mnet025') - '/home/user/.uniface/models/retinaface_mnet025.onnx' + >>> from uniface.models import RetinaFaceWeights, verify_model_weights + >>> verify_model_weights(RetinaFaceWeights.MNET_V2) + '/home/user/.uniface/models/retinaface_mnet_v2.onnx' - >>> # Use a custom directory - >>> verify_model_weights('retinaface_r34', root='/custom/dir') + >>> verify_model_weights(RetinaFaceWeights.RESNET34, root='/custom/dir') '/custom/dir/retinaface_r34.onnx' """ root = os.path.expanduser(root) os.makedirs(root, exist_ok=True) - model_path = os.path.join(root, f'{model_name}.onnx') + + model_name = model_name.value + model_path = os.path.normpath(os.path.join(root, f'{model_name}.onnx')) if not os.path.exists(model_path): url = const.MODEL_URLS.get(model_name) @@ -54,9 +57,9 @@ def verify_model_weights(model_name: str, root: str = '~/.uniface/models') -> st Logger.info(f"Downloading model '{model_name}' from {url}") download_file(url, model_path) - Logger.info(f"Successfully downloaded '{model_name}' to {os.path.normpath(model_path)}") + Logger.info(f"Successfully downloaded '{model_name}' to {model_path}") else: - Logger.info(f"Model '{model_name}' already exists at {os.path.normpath(model_path)}") + Logger.info(f"Model '{model_name}' already exists at {model_path}") expected_hash = const.MODEL_SHA256.get(model_name) if expected_hash and not verify_file_hash(model_path, expected_hash): @@ -72,10 +75,16 @@ def download_file(url: str, dest_path: str) -> None: try: response = requests.get(url, stream=True) response.raise_for_status() - with open(dest_path, "wb") as file: + with open(dest_path, "wb") as file, tqdm( + desc=f"Downloading {dest_path}", + unit='B', + unit_scale=True, + unit_divisor=1024 + ) as progress: for chunk in response.iter_content(chunk_size=const.CHUNK_SIZE): if chunk: file.write(chunk) + progress.update(len(chunk)) except requests.RequestException as e: raise ConnectionError(f"Failed to download file from {url}. Error: {e}") diff --git a/uniface/retinaface.py b/uniface/retinaface.py index 6f9ff0b..abbc5c4 100644 --- a/uniface/retinaface.py +++ b/uniface/retinaface.py @@ -13,7 +13,7 @@ from uniface.log import Logger from uniface.model_store import verify_model_weights from uniface.constants import RetinaFaceWeights from uniface.common import ( - nms, + non_max_supression, resize_image, decode_boxes, generate_anchors, @@ -161,7 +161,7 @@ class RetinaFace: """ original_height, original_width = image.shape[:2] - + if self.dynamic_size: height, width, _ = image.shape self._priors = generate_anchors(image_size=(height, width)) # generate anchors for each input image @@ -244,7 +244,7 @@ class RetinaFace: # Apply NMS detections = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) - keep = nms(detections, self.nms_thresh) + keep = non_max_supression(detections, self.nms_thresh) detections, landmarks = detections[keep], landmarks[keep] # Keep top-k detections @@ -255,7 +255,7 @@ class RetinaFace: return detections, landmarks def _scale_detections(self, boxes: np.ndarray, landmarks: np.ndarray, resize_factor: float, shape: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray]: - """Scale bounding boxes and landmarks to the original image size.""" + # Scale bounding boxes and landmarks to the original image size. bbox_scale = np.array([shape[0], shape[1]] * 2) boxes = boxes * bbox_scale / resize_factor