mirror of
https://github.com/yakhyo/uniface.git
synced 2025-12-30 09:02:25 +00:00
feat: Add new models
This commit is contained in:
@@ -5,3 +5,4 @@ onnxruntime-gpu
|
||||
scikit-image
|
||||
requests
|
||||
pytest
|
||||
tqdm
|
||||
@@ -4,8 +4,8 @@ import time
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from uniface.detection import RetinaFace, draw_detections
|
||||
from uniface.constants import RetinaFaceWeights
|
||||
from uniface.detection import RetinaFace, draw_detections, SCRFD
|
||||
from uniface.constants import RetinaFaceWeights, SCRFDWeights
|
||||
|
||||
|
||||
def run_inference(model, image_path, vis_threshold=0.6, save_dir="outputs"):
|
||||
|
||||
35
scripts/sha256_generate.py
Normal file
35
scripts/sha256_generate.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import argparse
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def compute_sha256(file_path: Path, chunk_size: int = 8192) -> str:
|
||||
sha256_hash = hashlib.sha256()
|
||||
with file_path.open("rb") as f:
|
||||
for chunk in iter(lambda: f.read(chunk_size), b""):
|
||||
sha256_hash.update(chunk)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compute SHA256 hash of a model weight file."
|
||||
)
|
||||
parser.add_argument(
|
||||
"file",
|
||||
type=Path,
|
||||
help="Path to the model weight file (.onnx, .pth, etc)."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.file.exists() or not args.file.is_file():
|
||||
print(f"File does not exist: {args.file}")
|
||||
return
|
||||
|
||||
sha256 = compute_sha256(args.file)
|
||||
print(f"`SHA256 hash for '{args.file.name}':\n{sha256}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -82,7 +82,7 @@ def generate_anchors(image_size: Tuple[int, int] = (640, 640)) -> np.ndarray:
|
||||
return output
|
||||
|
||||
|
||||
def nms(dets: List[np.ndarray], threshold: float):
|
||||
def non_max_supression(dets: List[np.ndarray], threshold: float):
|
||||
"""
|
||||
Apply Non-Maximum Suppression (NMS) to reduce overlapping bounding boxes based on a threshold.
|
||||
|
||||
|
||||
@@ -44,25 +44,75 @@ class RetinaFaceWeights(str, Enum):
|
||||
RESNET18 = "retinaface_r18"
|
||||
RESNET34 = "retinaface_r34"
|
||||
|
||||
|
||||
class SCRFDWeights(str, Enum):
|
||||
"""
|
||||
Trained on WIDER FACE dataset.
|
||||
https://github.com/deepinsight/insightface/tree/master/detection/scrfd
|
||||
"""
|
||||
SCRFD_10G_KPS = "scrfd_10g"
|
||||
SCRFD_500M_KPS = "scrfd_500m"
|
||||
|
||||
# fmt: on
|
||||
|
||||
|
||||
MODEL_URLS: Dict[RetinaFaceWeights, str] = {
|
||||
MODEL_URLS: Dict[Enum, str] = {
|
||||
|
||||
# RetinaFace
|
||||
RetinaFaceWeights.MNET_025: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv1_0.25.onnx',
|
||||
RetinaFaceWeights.MNET_050: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv1_0.50.onnx',
|
||||
RetinaFaceWeights.MNET_V1: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv1.onnx',
|
||||
RetinaFaceWeights.MNET_V2: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv2.onnx',
|
||||
RetinaFaceWeights.RESNET18: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_r18.onnx',
|
||||
RetinaFaceWeights.RESNET34: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_r34.onnx'
|
||||
RetinaFaceWeights.RESNET34: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_r34.onnx',
|
||||
|
||||
# MobileFace
|
||||
MobileFaceWeights.MNET_025: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/###',
|
||||
MobileFaceWeights.MNET_V2: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/###',
|
||||
MobileFaceWeights.MNET_V3_SMALL: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/###',
|
||||
MobileFaceWeights.MNET_V3_LARGE: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/###',
|
||||
|
||||
# SphereFace
|
||||
SphereFaceWeights.SPHERE20: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/###',
|
||||
SphereFaceWeights.SPHERE36: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/###',
|
||||
|
||||
|
||||
# ArcFace
|
||||
ArcFaceWeights.MNET: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/w600k_mbf.onnx',
|
||||
ArcFaceWeights.RESNET: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/w600k_r50.onnx',
|
||||
|
||||
# SCRFD
|
||||
SCRFDWeights.SCRFD_10G_KPS: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/scrfd_10g_kps.onnx',
|
||||
SCRFDWeights.SCRFD_500M_KPS: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/scrfd_500m_kps.onnx',
|
||||
}
|
||||
|
||||
MODEL_SHA256: Dict[RetinaFaceWeights, str] = {
|
||||
MODEL_SHA256: Dict[Enum, str] = {
|
||||
# RetinaFace
|
||||
RetinaFaceWeights.MNET_025: 'b7a7acab55e104dce6f32cdfff929bd83946da5cd869b9e2e9bdffafd1b7e4a5',
|
||||
RetinaFaceWeights.MNET_050: 'd8977186f6037999af5b4113d42ba77a84a6ab0c996b17c713cc3d53b88bfc37',
|
||||
RetinaFaceWeights.MNET_V1: '75c961aaf0aff03d13c074e9ec656e5510e174454dd4964a161aab4fe5f04153',
|
||||
RetinaFaceWeights.MNET_V2: '3ca44c045651cabeed1193a1fae8946ad1f3a55da8fa74b341feab5a8319f757',
|
||||
RetinaFaceWeights.RESNET18: 'e8b5ddd7d2c3c8f7c942f9f10cec09d8e319f78f09725d3f709631de34fb649d',
|
||||
RetinaFaceWeights.RESNET34: 'bd0263dc2a465d32859555cb1741f2d98991eb0053696e8ee33fec583d30e630'
|
||||
RetinaFaceWeights.RESNET34: 'bd0263dc2a465d32859555cb1741f2d98991eb0053696e8ee33fec583d30e630',
|
||||
|
||||
# MobileFace
|
||||
MobileFaceWeights.MNET_025: '#',
|
||||
MobileFaceWeights.MNET_V2: '#',
|
||||
MobileFaceWeights.MNET_V3_SMALL: '#',
|
||||
MobileFaceWeights.MNET_V3_LARGE: '#',
|
||||
|
||||
# SphereFace
|
||||
SphereFaceWeights.SPHERE20: '#',
|
||||
SphereFaceWeights.SPHERE36: '#',
|
||||
|
||||
|
||||
# ArcFace
|
||||
ArcFaceWeights.MNET: '9cc6e4a75f0e2bf0b1aed94578f144d15175f357bdc05e815e5c4a02b319eb4f',
|
||||
ArcFaceWeights.RESNET: '4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43',
|
||||
|
||||
# SCRFD
|
||||
SCRFDWeights.SCRFD_10G_KPS: '5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91',
|
||||
SCRFDWeights.SCRFD_500M_KPS: '5e4447f50245bbd7966bd6c0fa52938c61474a04ec7def48753668a9d8b4ea3a',
|
||||
}
|
||||
|
||||
CHUNK_SIZE = 8192
|
||||
|
||||
@@ -161,7 +161,7 @@ class RetinaFace:
|
||||
"""
|
||||
|
||||
original_height, original_width = image.shape[:2]
|
||||
|
||||
|
||||
if self.dynamic_size:
|
||||
height, width, _ = image.shape
|
||||
self._priors = generate_anchors(image_size=(height, width)) # generate anchors for each input image
|
||||
@@ -244,7 +244,7 @@ class RetinaFace:
|
||||
|
||||
# Apply NMS
|
||||
detections = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
|
||||
keep = nms(detections, self.nms_thresh)
|
||||
keep = non_max_supression(detections, self.nms_thresh)
|
||||
detections, landmarks = detections[keep], landmarks[keep]
|
||||
|
||||
# Keep top-k detections
|
||||
@@ -255,11 +255,11 @@ class RetinaFace:
|
||||
return detections, landmarks
|
||||
|
||||
def _scale_detections(self, boxes: np.ndarray, landmarks: np.ndarray, resize_factor: float, shape: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Scale bounding boxes and landmarks to the original image size."""
|
||||
# Scale bounding boxes and landmarks to the original image size.
|
||||
bbox_scale = np.array([shape[0], shape[1]] * 2)
|
||||
boxes = boxes * bbox_scale / resize_factor
|
||||
|
||||
landmark_scale = np.array([shape[0], shape[1]] * 5)
|
||||
landmarks = landmarks * landmark_scale / resize_factor
|
||||
|
||||
return boxes, landmarks
|
||||
return boxes, landmarks
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
# Copyright 2025 Yakhyokhuja Valikhujaev
|
||||
# Author: Yakhyokhuja Valikhujaev
|
||||
# GitHub: https://github.com/yakhyo
|
||||
# Modified from insightface repo
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
import onnxruntime as ort
|
||||
|
||||
from typing import Tuple, Optional, List, Literal
|
||||
|
||||
# from uniface.logger import Logger
|
||||
from uniface.log import Logger
|
||||
from uniface.model_store import verify_model_weights
|
||||
from uniface.constants import SCRFDWeights
|
||||
from .utils import non_max_supression, distance2bbox, distance2kps, resize_image
|
||||
|
||||
__all__ = ['SCRFD']
|
||||
@@ -13,58 +20,85 @@ __all__ = ['SCRFD']
|
||||
|
||||
class SCRFD:
|
||||
"""
|
||||
A class for face detection using the SCRFD model.
|
||||
|
||||
Title: "Sample and Computation Redistribution for Efficient Face Detection"
|
||||
Paper: https://arxiv.org/abs/2105.04714
|
||||
|
||||
Args:
|
||||
model_name (SCRFDWeights): Predefined model enum (e.g., SCRFD_10G_KPS). Specifies which SCRFD variant to load.
|
||||
conf_thresh (float): Confidence threshold for filtering detections. Defaults to 0.5.
|
||||
nms_thresh (float): Non-Maximum Suppression threshold. Defaults to 0.4.
|
||||
input_size (Optional[Tuple[int, int]]): Input resolution (height, width) to which the image is resized. Defaults to (640, 640).
|
||||
|
||||
Attributes:
|
||||
conf_thresh (float): Confidence threshold used to filter raw detections.
|
||||
nms_thresh (float): NMS threshold to suppress overlapping detections.
|
||||
input_size (Tuple[int, int]): Target resolution for input resizing.
|
||||
_fmc (int): Number of feature map scales used in SCRFD.
|
||||
_feat_stride_fpn (List[int]): Stride values for each feature map.
|
||||
_num_anchors (int): Number of anchors per feature point.
|
||||
_center_cache (Dict): Cache of anchor centers for efficient inference.
|
||||
_model_path (str): Verified path to the downloaded model weights.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model weights cannot be found or verified.
|
||||
RuntimeError: If the ONNX model fails to load or initialize.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
input_size: Tuple[int] = (640, 640),
|
||||
conf_thres: float = 0.5,
|
||||
iou_thres: float = 0.4
|
||||
model_name: SCRFDWeights = SCRFDWeights.SCRFD_10G_KPS,
|
||||
conf_thresh: float = 0.5,
|
||||
nms_thresh: float = 0.4,
|
||||
input_size: Optional[Tuple[int, int]] = (640, 640),
|
||||
) -> None:
|
||||
"""SCRFD initialization
|
||||
|
||||
Args:
|
||||
model_path (str): Path model .onnx file.
|
||||
input_size (int): Input image size. Defaults to (640, 640)
|
||||
max_num (int): Maximum number of detections
|
||||
conf_thres (float, optional): Confidence threshold. Defaults to 0.5.
|
||||
iou_thres (float, optional): Non-max supression (NMS) threshold. Defaults to 0.4.
|
||||
"""
|
||||
|
||||
self.conf_thresh = conf_thresh
|
||||
self.nms_thresh = nms_thresh
|
||||
self.input_size = input_size
|
||||
self.conf_thres = conf_thres
|
||||
self.iou_thres = iou_thres
|
||||
|
||||
# SCRFD model params --------------
|
||||
self.fmc = 3
|
||||
self._fmc = 3
|
||||
self._feat_stride_fpn = [8, 16, 32]
|
||||
self._num_anchors = 2
|
||||
|
||||
self.center_cache = {}
|
||||
self._center_cache = {}
|
||||
# ---------------------------------
|
||||
|
||||
self._initialize_model(model_path=model_path)
|
||||
Logger.info(
|
||||
f"Initializing SCRFD with model={model_name}, conf_thresh={conf_thresh}, nms_thresh={nms_thresh}, "
|
||||
f"input_size={input_size}"
|
||||
)
|
||||
|
||||
def _initialize_model(self, model_path: str):
|
||||
"""Initialize the model from the given path.
|
||||
# Get path to model weights
|
||||
self._model_path = verify_model_weights(model_name)
|
||||
Logger.info(f"Verified model weights located at: {self._model_path}")
|
||||
|
||||
# Initialize model
|
||||
self._initialize_model(self._model_path)
|
||||
|
||||
def _initialize_model(self, model_path: str) -> None:
|
||||
"""
|
||||
Initializes an ONNX model session from the given path.
|
||||
|
||||
Args:
|
||||
model_path (str): Path to .onnx model.
|
||||
model_path (str): The file path to the ONNX model.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the model fails to load, logs an error and raises an exception.
|
||||
"""
|
||||
try:
|
||||
self.session = onnxruntime.InferenceSession(
|
||||
self.session = ort.InferenceSession(
|
||||
model_path,
|
||||
providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
)
|
||||
# Get model info
|
||||
self.input_names = self.session.get_inputs()[0].name
|
||||
self.output_names = [x.name for x in self.session.get_outputs()]
|
||||
Logger.info(f"Successfully initialized the model from {model_path}")
|
||||
except Exception as e:
|
||||
print(f"Failed to load the model: {e}")
|
||||
raise
|
||||
Logger.error(f"Failed to load model from '{model_path}': {e}", exc_info=True)
|
||||
raise RuntimeError(f"Failed to initialize model session for '{model_path}'") from e
|
||||
|
||||
def preprocess(self, image: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
|
||||
"""Preprocess image for inference.
|
||||
@@ -75,14 +109,12 @@ class SCRFD:
|
||||
Returns:
|
||||
Tuple[np.ndarray, Tuple[int, int]]: Preprocessed blob and input size
|
||||
"""
|
||||
input_size = tuple(image.shape[0:2][::-1])
|
||||
|
||||
image = image.astype(np.float32)
|
||||
image = (image - 127.5) / 127.5
|
||||
image = image.transpose(2, 0, 1) # HWC to CHW
|
||||
image = np.expand_dims(image, axis=0)
|
||||
|
||||
return image, input_size
|
||||
return image
|
||||
|
||||
def inference(self, input_tensor: np.ndarray) -> List[np.ndarray]:
|
||||
"""Perform model inference on the preprocessed image tensor.
|
||||
@@ -95,26 +127,26 @@ class SCRFD:
|
||||
"""
|
||||
return self.session.run(self.output_names, {self.input_names: input_tensor})
|
||||
|
||||
def postprocess(self, outputs: List[np.ndarray], image_dim: Tuple[int, int]):
|
||||
def postprocess(self, outputs: List[np.ndarray], image_size: Tuple[int, int]):
|
||||
scores_list = []
|
||||
bboxes_list = []
|
||||
kpss_list = []
|
||||
|
||||
input_height, input_width = image_dim
|
||||
image_size = image_size
|
||||
|
||||
fmc = self.fmc
|
||||
fmc = self._fmc
|
||||
for idx, stride in enumerate(self._feat_stride_fpn):
|
||||
scores = outputs[idx]
|
||||
bbox_preds = outputs[fmc + idx] * stride
|
||||
kps_preds = outputs[2*fmc + idx] * stride
|
||||
|
||||
# Generate anchors
|
||||
fm_height = input_height // stride
|
||||
fm_width = input_width // stride
|
||||
fm_height = image_size[0] // stride
|
||||
fm_width = image_size[1] // stride
|
||||
cache_key = (fm_height, fm_width, stride)
|
||||
|
||||
if cache_key in self.center_cache:
|
||||
anchor_centers = self.center_cache[cache_key]
|
||||
if cache_key in self._center_cache:
|
||||
anchor_centers = self._center_cache[cache_key]
|
||||
else:
|
||||
y, x = np.mgrid[:fm_height, :fm_width]
|
||||
anchor_centers = np.stack((x, y), axis=-1).astype(np.float32)
|
||||
@@ -123,10 +155,10 @@ class SCRFD:
|
||||
if self._num_anchors > 1:
|
||||
anchor_centers = np.tile(anchor_centers[:, None, :], (1, self._num_anchors, 1)).reshape(-1, 2)
|
||||
|
||||
if len(self.center_cache) < 100:
|
||||
self.center_cache[cache_key] = anchor_centers
|
||||
if len(self._center_cache) < 100:
|
||||
self._center_cache[cache_key] = anchor_centers
|
||||
|
||||
pos_indices = np.where(scores >= self.conf_thres)[0]
|
||||
pos_indices = np.where(scores >= self.conf_thresh)[0]
|
||||
if len(pos_indices) == 0:
|
||||
continue
|
||||
|
||||
@@ -135,9 +167,9 @@ class SCRFD:
|
||||
scores_list.append(scores_selected)
|
||||
bboxes_list.append(bboxes)
|
||||
|
||||
kpss = distance2kps(anchor_centers, kps_preds)
|
||||
kpss = kpss.reshape((kpss.shape[0], -1, 2))
|
||||
kpss_list.append(kpss[pos_indices])
|
||||
landmarks = distance2kps(anchor_centers, kps_preds)
|
||||
landmarks = landmarks.reshape((landmarks.shape[0], -1, 2))
|
||||
kpss_list.append(landmarks[pos_indices])
|
||||
|
||||
return scores_list, bboxes_list, kpss_list
|
||||
|
||||
@@ -153,48 +185,57 @@ class SCRFD:
|
||||
|
||||
image, resize_factor = resize_image(image, target_shape=self.input_size)
|
||||
|
||||
image_tensor, _ = self.preprocess(image)
|
||||
image_tensor = self.preprocess(image)
|
||||
|
||||
# ONNXRuntime inference
|
||||
outputs = self.inference(image_tensor)
|
||||
|
||||
scores_list, bboxes_list, kpss_list = self.postprocess(outputs, image.shape[:2])
|
||||
scores_list, bboxes_list, kpss_list = self.postprocess(outputs, image_size=image.shape[:2])
|
||||
|
||||
scores = np.vstack(scores_list)
|
||||
scores_ravel = scores.ravel()
|
||||
order = scores_ravel.argsort()[::-1]
|
||||
bboxes = np.vstack(bboxes_list) / resize_factor
|
||||
|
||||
kpss = np.vstack(kpss_list) / resize_factor
|
||||
bboxes = np.vstack(bboxes_list) / resize_factor
|
||||
landmarks = np.vstack(kpss_list) / resize_factor
|
||||
|
||||
pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
|
||||
pre_det = pre_det[order, :]
|
||||
keep = non_max_supression(pre_det, threshold=self.iou_thres)
|
||||
det = pre_det[keep, :]
|
||||
|
||||
kpss = kpss[order, :, :]
|
||||
kpss = kpss[keep, :, :]
|
||||
keep = non_max_supression(pre_det, threshold=self.nms_thresh)
|
||||
|
||||
det = pre_det[keep, :]
|
||||
landmarks = landmarks[order, :, :]
|
||||
landmarks = landmarks[keep, :, :].astype(np.int32)
|
||||
|
||||
if 0 < max_num < det.shape[0]:
|
||||
# Calculate area of detections
|
||||
area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
|
||||
center = (original_height // 2, original_width // 2)
|
||||
|
||||
# Calculate offsets from image center
|
||||
center = (original_height // 2, original_width // 2)
|
||||
offsets = np.vstack(
|
||||
[
|
||||
(det[:, 0] + det[:, 2]) / 2 - center[1],
|
||||
(det[:, 1] + det[:, 3]) / 2 - center[0],
|
||||
]
|
||||
)
|
||||
|
||||
# Calculate scores based on the chosen metric
|
||||
offset_dist_squared = np.sum(np.power(offsets, 2.0), axis=0)
|
||||
if metric == "max":
|
||||
values = area
|
||||
else:
|
||||
values = area - offset_dist_squared * center_weight # some extra weight on the centering
|
||||
values = area - offset_dist_squared * center_weight
|
||||
|
||||
# Sort by scores and select top `max_num`
|
||||
sorted_indices = np.argsort(values)[::-1][:max_num]
|
||||
det = det[sorted_indices]
|
||||
kpss = kpss[sorted_indices]
|
||||
landmarks = landmarks[sorted_indices]
|
||||
|
||||
|
||||
|
||||
return det, kpss
|
||||
return det, landmarks
|
||||
|
||||
|
||||
def draw_bbox(frame, bbox, color=(0, 255, 0), thickness=2):
|
||||
@@ -210,7 +251,7 @@ def draw_keypoints(frame, points, color=(0, 0, 255), radius=2):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
detector = SCRFD(model_path="det_10g.onnx")
|
||||
detector = SCRFD()
|
||||
cap = cv2.VideoCapture(0)
|
||||
|
||||
if not cap.isOpened():
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
Logger = logging.getLogger("uniface")
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
import os
|
||||
import hashlib
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from uniface.log import Logger
|
||||
import uniface.constants as const
|
||||
@@ -15,36 +16,38 @@ __all__ = ['verify_model_weights']
|
||||
|
||||
def verify_model_weights(model_name: str, root: str = '~/.uniface/models') -> str:
|
||||
"""
|
||||
Ensures model weights are available by downloading if missing and verifying integrity with a SHA-256 hash.
|
||||
Ensure model weights are present, downloading and verifying them using SHA-256 if necessary.
|
||||
|
||||
Checks if the specified model weights file exists in `root`. If missing, downloads from a predefined URL.
|
||||
The file is then verified using its SHA-256 hash. If verification fails, the corrupted file is deleted,
|
||||
Given a model identifier from an Enum class (e.g., `RetinaFaceWeights.MNET_V2`), this function checks if
|
||||
the corresponding `.onnx` weight file exists locally. If not, it downloads the file from a predefined URL.
|
||||
After download, the file’s integrity is verified using a SHA-256 hash. If verification fails, the file is deleted
|
||||
and an error is raised.
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model weights to verify or download.
|
||||
root (str, optional): Directory to store the model weights. Defaults to '~/.uniface/models'.
|
||||
model_name (Enum): Model weight identifier (e.g., `RetinaFaceWeights.MNET_V2`, `ArcFaceWeights.RESNET`, etc.).
|
||||
root (str, optional): Directory to store or locate the model weights. Defaults to '~/.uniface/models'.
|
||||
|
||||
Returns:
|
||||
str: Path to the verified model weights file.
|
||||
str: Absolute path to the verified model weights file.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model is not found or if verification fails.
|
||||
ValueError: If the model is unknown or SHA-256 verification fails.
|
||||
ConnectionError: If downloading the file fails.
|
||||
|
||||
Examples:
|
||||
>>> # Download and verify 'retinaface_mnet025' weights
|
||||
>>> verify_model_weights('retinaface_mnet025')
|
||||
'/home/user/.uniface/models/retinaface_mnet025.onnx'
|
||||
>>> from uniface.models import RetinaFaceWeights, verify_model_weights
|
||||
>>> verify_model_weights(RetinaFaceWeights.MNET_V2)
|
||||
'/home/user/.uniface/models/retinaface_mnet_v2.onnx'
|
||||
|
||||
>>> # Use a custom directory
|
||||
>>> verify_model_weights('retinaface_r34', root='/custom/dir')
|
||||
>>> verify_model_weights(RetinaFaceWeights.RESNET34, root='/custom/dir')
|
||||
'/custom/dir/retinaface_r34.onnx'
|
||||
"""
|
||||
|
||||
root = os.path.expanduser(root)
|
||||
os.makedirs(root, exist_ok=True)
|
||||
model_path = os.path.join(root, f'{model_name}.onnx')
|
||||
|
||||
model_name = model_name.value
|
||||
model_path = os.path.normpath(os.path.join(root, f'{model_name}.onnx'))
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
url = const.MODEL_URLS.get(model_name)
|
||||
@@ -54,9 +57,9 @@ def verify_model_weights(model_name: str, root: str = '~/.uniface/models') -> st
|
||||
|
||||
Logger.info(f"Downloading model '{model_name}' from {url}")
|
||||
download_file(url, model_path)
|
||||
Logger.info(f"Successfully downloaded '{model_name}' to {os.path.normpath(model_path)}")
|
||||
Logger.info(f"Successfully downloaded '{model_name}' to {model_path}")
|
||||
else:
|
||||
Logger.info(f"Model '{model_name}' already exists at {os.path.normpath(model_path)}")
|
||||
Logger.info(f"Model '{model_name}' already exists at {model_path}")
|
||||
|
||||
expected_hash = const.MODEL_SHA256.get(model_name)
|
||||
if expected_hash and not verify_file_hash(model_path, expected_hash):
|
||||
@@ -72,10 +75,16 @@ def download_file(url: str, dest_path: str) -> None:
|
||||
try:
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
with open(dest_path, "wb") as file:
|
||||
with open(dest_path, "wb") as file, tqdm(
|
||||
desc=f"Downloading {dest_path}",
|
||||
unit='B',
|
||||
unit_scale=True,
|
||||
unit_divisor=1024
|
||||
) as progress:
|
||||
for chunk in response.iter_content(chunk_size=const.CHUNK_SIZE):
|
||||
if chunk:
|
||||
file.write(chunk)
|
||||
progress.update(len(chunk))
|
||||
except requests.RequestException as e:
|
||||
raise ConnectionError(f"Failed to download file from {url}. Error: {e}")
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from uniface.log import Logger
|
||||
from uniface.model_store import verify_model_weights
|
||||
from uniface.constants import RetinaFaceWeights
|
||||
from uniface.common import (
|
||||
nms,
|
||||
non_max_supression,
|
||||
resize_image,
|
||||
decode_boxes,
|
||||
generate_anchors,
|
||||
@@ -161,7 +161,7 @@ class RetinaFace:
|
||||
"""
|
||||
|
||||
original_height, original_width = image.shape[:2]
|
||||
|
||||
|
||||
if self.dynamic_size:
|
||||
height, width, _ = image.shape
|
||||
self._priors = generate_anchors(image_size=(height, width)) # generate anchors for each input image
|
||||
@@ -244,7 +244,7 @@ class RetinaFace:
|
||||
|
||||
# Apply NMS
|
||||
detections = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
|
||||
keep = nms(detections, self.nms_thresh)
|
||||
keep = non_max_supression(detections, self.nms_thresh)
|
||||
detections, landmarks = detections[keep], landmarks[keep]
|
||||
|
||||
# Keep top-k detections
|
||||
@@ -255,7 +255,7 @@ class RetinaFace:
|
||||
return detections, landmarks
|
||||
|
||||
def _scale_detections(self, boxes: np.ndarray, landmarks: np.ndarray, resize_factor: float, shape: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Scale bounding boxes and landmarks to the original image size."""
|
||||
# Scale bounding boxes and landmarks to the original image size.
|
||||
bbox_scale = np.array([shape[0], shape[1]] * 2)
|
||||
boxes = boxes * bbox_scale / resize_factor
|
||||
|
||||
|
||||
Reference in New Issue
Block a user