mirror of
https://github.com/yakhyo/uniface.git
synced 2025-12-30 09:02:25 +00:00
feat: Update recognition, landmark modules
This commit is contained in:
@@ -2,18 +2,21 @@ import cv2
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from uniface.detection import RetinaFace
|
||||
from uniface.constants import RetinaFaceWeights
|
||||
from uniface.recognition import ArcFace
|
||||
# Use the new high-level factory functions for consistency
|
||||
from uniface.detection import create_detector
|
||||
from uniface.recognition import create_recognizer
|
||||
|
||||
# Import enums for argument choices
|
||||
from uniface.constants import RetinaFaceWeights, ArcFaceWeights, MobileFaceWeights, SphereFaceWeights
|
||||
|
||||
|
||||
def run_inference(detector, recognizer, image_path):
|
||||
def run_inference(detector, recognizer, image_path: str):
|
||||
"""
|
||||
Detect faces and extract embeddings from a single image.
|
||||
|
||||
Args:
|
||||
detector (RetinaFace): Initialized face detector.
|
||||
recognizer (ArcFace): Face recognition model.
|
||||
detector: Initialized face detector.
|
||||
recognizer: Initialized face recognition model.
|
||||
image_path (str): Path to the input image.
|
||||
"""
|
||||
image = cv2.imread(image_path)
|
||||
@@ -21,36 +24,53 @@ def run_inference(detector, recognizer, image_path):
|
||||
print(f"Error: Failed to load image from '{image_path}'")
|
||||
return
|
||||
|
||||
boxes, landmarks = detector.detect(image)
|
||||
faces = detector.detect(image)
|
||||
|
||||
if len(boxes) == 0:
|
||||
if not faces:
|
||||
print("No faces detected.")
|
||||
return
|
||||
|
||||
print(f"Detected {len(boxes)} face(s). Extracting embeddings...")
|
||||
print(f"Detected {len(faces)} face(s). Extracting embeddings for the first face...")
|
||||
|
||||
for i, landmark in enumerate(landmarks[:1]):
|
||||
embedding = recognizer.get_embedding(image, landmark)
|
||||
norm_embedding = recognizer.get_normalized_embedding(image, landmark)
|
||||
print("embedding:", np.sum(embedding))
|
||||
print("norm embedding:",np.sum(norm_embedding))
|
||||
# Process the first detected face
|
||||
first_face = faces[0]
|
||||
landmarks = np.array(first_face['landmarks']) # Convert landmarks to numpy array
|
||||
|
||||
# Extract embedding using the landmarks from the face dictionary
|
||||
embedding = recognizer.get_embedding(image, landmarks)
|
||||
norm_embedding = recognizer.get_normalized_embedding(image, landmarks)
|
||||
|
||||
# Print some info about the embeddings
|
||||
print(f" - Embedding shape: {embedding.shape}")
|
||||
print(f" - L2 norm of unnormalized embedding: {np.linalg.norm(embedding):.4f}")
|
||||
print(f" - L2 norm of normalized embedding: {np.linalg.norm(norm_embedding):.4f}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Extract face embeddings from a single image.")
|
||||
parser.add_argument("--image", type=str, required=True, help="Path to the input image.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"--detector",
|
||||
type=str,
|
||||
default="MNET_V2",
|
||||
choices=[m.name for m in RetinaFaceWeights],
|
||||
help="RetinaFace model variant to use."
|
||||
default="retinaface",
|
||||
choices=['retinaface', 'scrfd'],
|
||||
help="Face detection method to use."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recognizer",
|
||||
type=str,
|
||||
default="arcface",
|
||||
choices=['arcface', 'mobileface', 'sphereface'],
|
||||
help="Face recognition method to use."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
detector = RetinaFace(model_name=RetinaFaceWeights[args.model])
|
||||
recognizer = ArcFace()
|
||||
print(f"Initializing detector: {args.detector}")
|
||||
detector = create_detector(method=args.detector, model_name=RetinaFaceWeights.MNET_V2)
|
||||
|
||||
print(f"Initializing recognizer: {args.recognizer}")
|
||||
recognizer = create_recognizer(method=args.recognizer)
|
||||
|
||||
run_inference(detector, recognizer, args.image)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user