mirror of
https://github.com/yakhyo/uniface.git
synced 2025-12-30 09:02:25 +00:00
feat: Update recognition, landmark modules
This commit is contained in:
@@ -4,16 +4,17 @@ import time
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from uniface.detection import RetinaFace, draw_detections, SCRFD
|
||||
from uniface.constants import RetinaFaceWeights, SCRFDWeights
|
||||
# UPDATED: Use the factory function and import from the new location
|
||||
from uniface.detection import create_detector
|
||||
from uniface.visualization import draw_detections
|
||||
|
||||
|
||||
def run_inference(model, image_path, vis_threshold=0.6, save_dir="outputs"):
|
||||
def run_inference(detector, image_path: str, vis_threshold: float = 0.6, save_dir: str = "outputs"):
|
||||
"""
|
||||
Run face detection on a single image.
|
||||
|
||||
Args:
|
||||
model (RetinaFace): Initialized RetinaFace model.
|
||||
detector: Initialized face detector.
|
||||
image_path (str): Path to input image.
|
||||
vis_threshold (float): Threshold for drawing detections.
|
||||
save_dir (str): Directory to save output image.
|
||||
@@ -23,8 +24,18 @@ def run_inference(model, image_path, vis_threshold=0.6, save_dir="outputs"):
|
||||
print(f"❌ Error: Failed to load image from '{image_path}'")
|
||||
return
|
||||
|
||||
boxes, landmarks = model.detect(image)
|
||||
draw_detections(image, (boxes, landmarks), vis_threshold)
|
||||
# 1. Get the list of face dictionaries from the detector
|
||||
faces = detector.detect(image)
|
||||
|
||||
if faces:
|
||||
# 2. Unpack the data into separate lists
|
||||
bboxes = [face['bbox'] for face in faces]
|
||||
scores = [face['confidence'] for face in faces]
|
||||
landmarks = [face['landmarks'] for face in faces]
|
||||
|
||||
# 3. Pass the unpacked lists to the drawing function
|
||||
draw_detections(image, bboxes, scores, landmarks, vis_threshold=0.6)
|
||||
|
||||
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
output_path = os.path.join(save_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}_out.jpg")
|
||||
@@ -33,28 +44,38 @@ def run_inference(model, image_path, vis_threshold=0.6, save_dir="outputs"):
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run RetinaFace inference on an image.")
|
||||
parser = argparse.ArgumentParser(description="Run face detection on an image.")
|
||||
parser.add_argument("--image", type=str, required=True, help="Path to the input image")
|
||||
parser.add_argument("--model", type=str, default="MNET_V2", choices=[m.name for m in RetinaFaceWeights], help="Model variant to use")
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="retinaface",
|
||||
choices=['retinaface', 'scrfd'],
|
||||
help="Detection method to use."
|
||||
)
|
||||
parser.add_argument("--threshold", type=float, default=0.6, help="Visualization confidence threshold")
|
||||
parser.add_argument("--iterations", type=int, default=1, help="Number of inference runs for benchmarking")
|
||||
parser.add_argument("--save_dir", type=str, default="outputs", help="Directory to save output images")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
model_name = RetinaFaceWeights[args.model]
|
||||
model = RetinaFace(model_name=model_name)
|
||||
print(f"Initializing detector: {args.method}")
|
||||
detector = create_detector(method=args.method)
|
||||
|
||||
avg_time = 0
|
||||
for i in range(args.iterations):
|
||||
start = time.time()
|
||||
run_inference(model, args.image, args.threshold, args.save_dir)
|
||||
run_inference(detector, args.image, args.threshold, args.save_dir)
|
||||
elapsed = time.time() - start
|
||||
print(f"[{i + 1}/{args.iterations}] ⏱️ Inference time: {elapsed:.4f} seconds")
|
||||
avg_time += elapsed
|
||||
if i >= 0: # Avoid counting the first run if it includes model loading time
|
||||
avg_time += elapsed
|
||||
|
||||
if args.iterations > 1:
|
||||
print(f"\n🔥 Average inference time over {args.iterations} runs: {avg_time / args.iterations:.4f} seconds")
|
||||
# Adjust average calculation to exclude potential first-run overhead
|
||||
effective_iterations = max(1, args.iterations)
|
||||
print(
|
||||
f"\n🔥 Average inference time over {effective_iterations} runs: {avg_time / effective_iterations:.4f} seconds")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
101
scripts/run_face_search.py
Normal file
101
scripts/run_face_search.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import cv2
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
# Use the new high-level factory functions
|
||||
from uniface.detection import create_detector
|
||||
from uniface.recognition import create_recognizer
|
||||
from uniface.face_utils import compute_similarity
|
||||
|
||||
|
||||
def extract_reference_embedding(detector, recognizer, image_path: str) -> np.ndarray:
|
||||
"""Extracts a normalized embedding from the first face found in an image."""
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
raise RuntimeError(f"Failed to load image: {image_path}")
|
||||
|
||||
faces = detector.detect(image)
|
||||
if not faces:
|
||||
raise RuntimeError("No faces found in reference image.")
|
||||
|
||||
# Get landmarks from the first detected face dictionary
|
||||
landmarks = np.array(faces[0]['landmarks'])
|
||||
|
||||
# Use normalized embedding for more reliable similarity comparison
|
||||
embedding = recognizer.get_normalized_embedding(image, landmarks)
|
||||
return embedding
|
||||
|
||||
|
||||
def run_video(detector, recognizer, ref_embedding: np.ndarray, threshold: float = 0.4):
|
||||
"""Run real-time face recognition from a webcam feed."""
|
||||
cap = cv2.VideoCapture(0)
|
||||
if not cap.isOpened():
|
||||
raise RuntimeError("Webcam could not be opened.")
|
||||
print("Webcam started. Press 'q' to quit.")
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
faces = detector.detect(frame)
|
||||
|
||||
# Loop through each detected face
|
||||
for face in faces:
|
||||
# Extract bbox and landmarks from the dictionary
|
||||
bbox = face['bbox']
|
||||
landmarks = np.array(face['landmarks'])
|
||||
|
||||
x1, y1, x2, y2 = map(int, bbox)
|
||||
|
||||
# Get the normalized embedding for the current face
|
||||
embedding = recognizer.get_normalized_embedding(frame, landmarks)
|
||||
|
||||
# Compare with the reference embedding
|
||||
sim = compute_similarity(ref_embedding, embedding)
|
||||
|
||||
# Draw results
|
||||
label = f"Match ({sim:.2f})" if sim > threshold else f"Unknown ({sim:.2f})"
|
||||
color = (0, 255, 0) if sim > threshold else (0, 0, 255)
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
|
||||
cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
|
||||
|
||||
cv2.imshow("Face Recognition", frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Face recognition using a reference image.")
|
||||
parser.add_argument("--image", type=str, required=True, help="Path to the reference face image.")
|
||||
parser.add_argument(
|
||||
"--detector",
|
||||
type=str,
|
||||
default="scrfd",
|
||||
choices=['retinaface', 'scrfd'],
|
||||
help="Face detection method."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recognizer",
|
||||
type=str,
|
||||
default="arcface",
|
||||
choices=['arcface', 'mobileface', 'sphereface'],
|
||||
help="Face recognition method."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Initializing models...")
|
||||
detector = create_detector(method=args.detector)
|
||||
recognizer = create_recognizer(method=args.recognizer)
|
||||
|
||||
print("Extracting reference embedding...")
|
||||
ref_embedding = extract_reference_embedding(detector, recognizer, args.image)
|
||||
|
||||
run_video(detector, recognizer, ref_embedding)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -2,18 +2,21 @@ import cv2
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from uniface.detection import RetinaFace
|
||||
from uniface.constants import RetinaFaceWeights
|
||||
from uniface.recognition import ArcFace
|
||||
# Use the new high-level factory functions for consistency
|
||||
from uniface.detection import create_detector
|
||||
from uniface.recognition import create_recognizer
|
||||
|
||||
# Import enums for argument choices
|
||||
from uniface.constants import RetinaFaceWeights, ArcFaceWeights, MobileFaceWeights, SphereFaceWeights
|
||||
|
||||
|
||||
def run_inference(detector, recognizer, image_path):
|
||||
def run_inference(detector, recognizer, image_path: str):
|
||||
"""
|
||||
Detect faces and extract embeddings from a single image.
|
||||
|
||||
Args:
|
||||
detector (RetinaFace): Initialized face detector.
|
||||
recognizer (ArcFace): Face recognition model.
|
||||
detector: Initialized face detector.
|
||||
recognizer: Initialized face recognition model.
|
||||
image_path (str): Path to the input image.
|
||||
"""
|
||||
image = cv2.imread(image_path)
|
||||
@@ -21,36 +24,53 @@ def run_inference(detector, recognizer, image_path):
|
||||
print(f"Error: Failed to load image from '{image_path}'")
|
||||
return
|
||||
|
||||
boxes, landmarks = detector.detect(image)
|
||||
faces = detector.detect(image)
|
||||
|
||||
if len(boxes) == 0:
|
||||
if not faces:
|
||||
print("No faces detected.")
|
||||
return
|
||||
|
||||
print(f"Detected {len(boxes)} face(s). Extracting embeddings...")
|
||||
print(f"Detected {len(faces)} face(s). Extracting embeddings for the first face...")
|
||||
|
||||
for i, landmark in enumerate(landmarks[:1]):
|
||||
embedding = recognizer.get_embedding(image, landmark)
|
||||
norm_embedding = recognizer.get_normalized_embedding(image, landmark)
|
||||
print("embedding:", np.sum(embedding))
|
||||
print("norm embedding:",np.sum(norm_embedding))
|
||||
# Process the first detected face
|
||||
first_face = faces[0]
|
||||
landmarks = np.array(first_face['landmarks']) # Convert landmarks to numpy array
|
||||
|
||||
# Extract embedding using the landmarks from the face dictionary
|
||||
embedding = recognizer.get_embedding(image, landmarks)
|
||||
norm_embedding = recognizer.get_normalized_embedding(image, landmarks)
|
||||
|
||||
# Print some info about the embeddings
|
||||
print(f" - Embedding shape: {embedding.shape}")
|
||||
print(f" - L2 norm of unnormalized embedding: {np.linalg.norm(embedding):.4f}")
|
||||
print(f" - L2 norm of normalized embedding: {np.linalg.norm(norm_embedding):.4f}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Extract face embeddings from a single image.")
|
||||
parser.add_argument("--image", type=str, required=True, help="Path to the input image.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"--detector",
|
||||
type=str,
|
||||
default="MNET_V2",
|
||||
choices=[m.name for m in RetinaFaceWeights],
|
||||
help="RetinaFace model variant to use."
|
||||
default="retinaface",
|
||||
choices=['retinaface', 'scrfd'],
|
||||
help="Face detection method to use."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recognizer",
|
||||
type=str,
|
||||
default="arcface",
|
||||
choices=['arcface', 'mobileface', 'sphereface'],
|
||||
help="Face recognition method to use."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
detector = RetinaFace(model_name=RetinaFaceWeights[args.model])
|
||||
recognizer = ArcFace()
|
||||
print(f"Initializing detector: {args.detector}")
|
||||
detector = create_detector(method=args.detector, model_name=RetinaFaceWeights.MNET_V2)
|
||||
|
||||
print(f"Initializing recognizer: {args.recognizer}")
|
||||
recognizer = create_recognizer(method=args.recognizer)
|
||||
|
||||
run_inference(detector, recognizer, args.image)
|
||||
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
import cv2
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from uniface.detection import RetinaFace
|
||||
from uniface.constants import RetinaFaceWeights
|
||||
from uniface.recognition import ArcFace
|
||||
from uniface.face_utils import compute_similarity
|
||||
|
||||
|
||||
def extract_reference_embedding(detector, recognizer, image_path):
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
raise RuntimeError(f"Failed to load image: {image_path}")
|
||||
|
||||
boxes, landmarks = detector.detect(image)
|
||||
if len(boxes) == 0:
|
||||
raise RuntimeError("No faces found in reference image.")
|
||||
|
||||
embedding = recognizer.get_embedding(image, landmarks[0])
|
||||
return embedding
|
||||
|
||||
|
||||
def run_video(detector, recognizer, ref_embedding, threshold=0.30):
|
||||
cap = cv2.VideoCapture(0)
|
||||
if not cap.isOpened():
|
||||
raise RuntimeError("Webcam could not be opened.")
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
boxes, landmarks = detector.detect(frame)
|
||||
|
||||
for box, lm in zip(boxes, landmarks):
|
||||
x1, y1, x2, y2 = map(int, box[:4])
|
||||
embedding = recognizer.get_embedding(frame, lm)
|
||||
sim = compute_similarity(ref_embedding, embedding)
|
||||
label = f"Match ({sim:.2f})" if sim > threshold else f"Unknown ({sim:.2f})"
|
||||
color = (0, 255, 0) if sim > threshold else (0, 0, 255)
|
||||
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
|
||||
cv2.putText(frame, label, (x1, y1 - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
|
||||
|
||||
cv2.imshow("Face Recognition", frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Face recognition using a reference image.")
|
||||
parser.add_argument("--image", type=str, required=True, help="Path to the reference face image.")
|
||||
parser.add_argument("--model", type=str, default="MNET_V2",
|
||||
choices=[m.name for m in RetinaFaceWeights], help="Face detector model.")
|
||||
args = parser.parse_args()
|
||||
|
||||
detector = RetinaFace(model_name=RetinaFaceWeights[args.model])
|
||||
recognizer = ArcFace()
|
||||
ref_embedding = extract_reference_embedding(detector, recognizer, args.image)
|
||||
run_video(detector, recognizer, ref_embedding)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user