Mirror of https://github.com/yakhyo/uniface.git

Commit: ref: Update attribute and landmark modules
.gitignore (vendored)
@@ -1,3 +1,5 @@
+tmp_*
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
uniface/attribute/__init__.py
@@ -1,2 +1,96 @@ (replaces the old two-line module, which only re-exported AgeGender and Emotion via relative imports; the new module contents follow)

# Copyright 2025 Yakhyokhuja Valikhujaev
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo

from typing import Dict, Any, List, Union

import numpy as np

from uniface.attribute.age_gender import AgeGender
from uniface.attribute.emotion import Emotion
from uniface.attribute.base import Attribute
from uniface.constants import AgeGenderWeights, DDAMFNWeights

# Public API for the attribute module
__all__ = [
    "AgeGender",
    "Emotion",
    "create_attribute_predictor",
    "predict_attributes"
]

# A mapping from model enums to their corresponding attribute classes
_ATTRIBUTE_MODELS = {
    **{model: AgeGender for model in AgeGenderWeights},
    **{model: Emotion for model in DDAMFNWeights}
}


def create_attribute_predictor(
    model_name: Union[AgeGenderWeights, DDAMFNWeights],
    **kwargs: Any
) -> Attribute:
    """
    Factory function to create an attribute predictor instance.

    This high-level API simplifies the creation of attribute models by
    dynamically selecting the correct class based on the provided model enum.

    Args:
        model_name: The enum corresponding to the desired attribute model
            (e.g., AgeGenderWeights.DEFAULT or DDAMFNWeights.AFFECNET7).
        **kwargs: Additional keyword arguments to pass to the model's constructor.

    Returns:
        An initialized instance of an Attribute predictor class (e.g., AgeGender).

    Raises:
        ValueError: If the provided model_name is not a supported enum.
    """
    model_class = _ATTRIBUTE_MODELS.get(model_name)

    if model_class is None:
        raise ValueError(f"Unsupported attribute model: {model_name}. "
                         f"Please choose from AgeGenderWeights or DDAMFNWeights.")

    # Pass model_name to the constructor, as some classes might need it
    return model_class(model_name=model_name, **kwargs)


def predict_attributes(
    image: np.ndarray,
    detections: List[Dict[str, np.ndarray]],
    predictor: Attribute
) -> List[Dict[str, Any]]:
    """
    High-level API to predict attributes for multiple detected faces.

    This function iterates through a list of face detections, runs the
    specified attribute predictor on each one, and appends the results back
    into the detection dictionary.

    Args:
        image (np.ndarray): The full input image in BGR format.
        detections (List[Dict]): A list of detection results, where each dict
            must contain a 'bbox' and optionally 'landmark'.
        predictor (Attribute): An initialized attribute predictor instance,
            created by `create_attribute_predictor`.

    Returns:
        The list of detections, where each dictionary is updated with a new
        'attributes' key containing the prediction result.
    """
    for face in detections:
        # Initialize attributes dict if it doesn't exist
        if 'attributes' not in face:
            face['attributes'] = {}

        if isinstance(predictor, AgeGender):
            gender, age = predictor(image, face['bbox'])
            face['attributes']['gender'] = gender
            face['attributes']['age'] = age
        elif isinstance(predictor, Emotion):
            emotion, confidence = predictor(image, face['landmark'])
            face['attributes']['emotion'] = emotion
            face['attributes']['confidence'] = confidence

    return detections
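For orientation, here is a minimal usage sketch of the new high-level API. It assumes a detector built with create_detector whose detect() returns dicts with a 'bbox' key, as in the test blocks later in this commit; the image path is hypothetical and the sketch is illustrative only, not part of the committed code.

# Usage sketch (illustrative only): annotate detector output with age/gender.
import cv2

from uniface.detection import create_detector
from uniface.constants import RetinaFaceWeights, AgeGenderWeights
from uniface.attribute import create_attribute_predictor, predict_attributes

image = cv2.imread("sample.jpg")  # hypothetical input image
detector = create_detector(model_name=RetinaFaceWeights.MNET_V2)
detections = detector.detect(image)  # list of dicts containing a 'bbox' key

predictor = create_attribute_predictor(AgeGenderWeights.DEFAULT)
detections = predict_attributes(image, detections, predictor)

for det in detections:
    print(det['attributes'])  # e.g. {'gender': 'Female', 'age': 29}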
uniface/attribute/age_gender.py
@@ -5,232 +5,176 @@ (the class was rewritten on top of the new Attribute base; the updated contents follow)

import cv2
import numpy as np
import onnxruntime as ort

from typing import Tuple, Union, List

from uniface.attribute.base import Attribute

from uniface.log import Logger
from uniface.constants import AgeGenderWeights
from uniface.face_utils import bbox_center_alignment
from uniface.model_store import verify_model_weights


__all__ = ["AgeGender"]


class AgeGender(Attribute):
    """
    Age and gender prediction model using ONNX Runtime.

    This class inherits from the base `Attribute` class and implements the
    functionality for predicting age (in years) and gender (0 for female,
    1 for male) from a face image. It requires a bounding box to locate the face.
    """

    def __init__(self, model_name: AgeGenderWeights = AgeGenderWeights.DEFAULT) -> None:
        """
        Initializes the AgeGender prediction model.

        Args:
            model_name (AgeGenderWeights): The enum specifying the model weights
                to load.
        """
        Logger.info(f"Initializing AgeGender with model={model_name.name}")
        self.model_path = verify_model_weights(model_name)
        self._initialize_model()

    def _initialize_model(self) -> None:
        """
        Initializes the ONNX model and creates an inference session.
        """
        try:
            self.session = ort.InferenceSession(
                self.model_path,
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
            )
            # Get model input details from the loaded model
            input_meta = self.session.get_inputs()[0]
            self.input_name = input_meta.name
            self.input_size = tuple(input_meta.shape[2:4])  # (height, width)
            self.output_names = [output.name for output in self.session.get_outputs()]
            Logger.info(f"Successfully initialized AgeGender model with input size {self.input_size}")
        except Exception as e:
            Logger.error(f"Failed to load AgeGender model from '{self.model_path}'", exc_info=True)
            raise RuntimeError(f"Failed to initialize AgeGender model: {e}")

    def preprocess(self, image: np.ndarray, bbox: Union[List, np.ndarray]) -> np.ndarray:
        """
        Aligns the face based on the bounding box and preprocesses it for inference.

        Args:
            image (np.ndarray): The full input image in BGR format.
            bbox (Union[List, np.ndarray]): The face bounding box coordinates [x1, y1, x2, y2].

        Returns:
            np.ndarray: The preprocessed image blob ready for inference.
        """
        bbox = np.asarray(bbox)

        width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
        center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
        scale = self.input_size[1] / (max(width, height) * 1.5)

        # Rotation parameter restored here (no in-plane rotation applied)
        rotation = 0.0

        aligned_face, _ = bbox_center_alignment(
            image, center, self.input_size[1], scale, rotation
        )

        blob = cv2.dnn.blobFromImage(
            aligned_face,
            scalefactor=1.0,
            size=self.input_size[::-1],
            mean=(0.0, 0.0, 0.0),
            swapRB=True
        )
        return blob

    def postprocess(self, prediction: np.ndarray) -> Tuple[str, int]:
        """
        Processes the raw model output to extract gender and age.

        Args:
            prediction (np.ndarray): The raw output from the model inference.

        Returns:
            Tuple[str, int]: A tuple containing the predicted gender label
                ("Female" or "Male") and age (in years).
        """
        # First two values are gender logits
        gender_id = int(np.argmax(prediction[:2]))
        gender = "Female" if gender_id == 0 else "Male"

        # Third value is normalized age, scaled by 100
        age = int(np.round(prediction[2] * 100))

        return gender, age

    def predict(self, image: np.ndarray, bbox: Union[List, np.ndarray]) -> Tuple[str, int]:
        """
        Predicts age and gender for a single face specified by a bounding box.

        Args:
            image (np.ndarray): The full input image in BGR format.
            bbox (Union[List, np.ndarray]): The face bounding box coordinates [x1, y1, x2, y2].

        Returns:
            Tuple[str, int]: A tuple containing the predicted gender label and age.
        """
        face_blob = self.preprocess(image, bbox)
        prediction = self.session.run(self.output_names, {self.input_name: face_blob})[0][0]
        gender, age = self.postprocess(prediction)
        return gender, age


# TODO: below is only for testing, remove it later
if __name__ == "__main__":
    # To run this script, you need to have uniface.detection installed
    # or available in your path.
    from uniface.detection import create_detector
    from uniface.constants import RetinaFaceWeights

    print("Initializing models for live inference...")

    # 1. Initialize the face detector
    # Using a smaller model for faster real-time performance
    detector = create_detector(model_name=RetinaFaceWeights.MNET_V2)

    # 2. Initialize the attribute predictor
    age_gender_predictor = AgeGender()

    # 3. Start webcam capture
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open webcam.")
        exit()

    print("Starting webcam feed. Press 'q' to quit.")
    while True:
        ret, frame = cap.read()
        if not ret:
            print("Error: Failed to capture frame.")
            break

        # Detect faces in the current frame
        detections = detector.detect(frame)

        # For each detected face, predict age and gender
        for detection in detections:
            box = detection['bbox']
            x1, y1, x2, y2 = map(int, box)

            # Predict attributes
            gender, age = age_gender_predictor.predict(frame, box)

            # Prepare text and draw on the frame
            label = f"{gender}, {age}"
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)

        # Display the resulting frame
        cv2.imshow("Age and Gender Inference (Press 'q' to quit)", frame)

        # Break the loop if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # Release resources
    cap.release()
    cv2.destroyAllWindows()
    print("Inference stopped.")
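A minimal single-image sketch of the class above: the image path and bounding box values are hypothetical, and in practice the box would come from a face detector.

# Single-image sketch (hypothetical path and bbox values).
import cv2

from uniface.attribute.age_gender import AgeGender

image = cv2.imread("face.jpg")
predictor = AgeGender()  # loads AgeGenderWeights.DEFAULT

bbox = [120, 80, 320, 310]  # [x1, y1, x2, y2], hypothetical values
gender, age = predictor.predict(image, bbox)  # predictor(image, bbox) also works via __call__
print(f"{gender}, {age} years")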
uniface/attribute/base.py (new file, 91 lines)
@@ -0,0 +1,91 @@

# Copyright 2025 Yakhyokhuja Valikhujaev
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo

from abc import ABC, abstractmethod
from typing import Any

import numpy as np


class Attribute(ABC):
    """
    Abstract base class for face attribute models.

    This class defines the common interface that all attribute models
    (e.g., age-gender, emotion) must implement. It ensures a consistent API
    across different attribute prediction modules in the library, making them
    interchangeable and easy to use.
    """

    @abstractmethod
    def _initialize_model(self) -> None:
        """
        Initializes the underlying model for inference.

        This method should handle loading model weights, creating the
        inference session (e.g., ONNX Runtime, PyTorch), and any necessary
        warm-up procedures to prepare the model for prediction.
        """
        raise NotImplementedError("Subclasses must implement the _initialize_model method.")

    @abstractmethod
    def preprocess(self, image: np.ndarray, *args: Any) -> Any:
        """
        Preprocesses the input data for the model.

        This method should take a raw image and any other necessary data
        (like bounding boxes or landmarks) and convert it into the format
        expected by the model's inference engine (e.g., a blob or tensor).

        Args:
            image (np.ndarray): The input image containing the face, typically
                in BGR format.
            *args: Additional arguments required for preprocessing, such as
                bounding boxes or facial landmarks.

        Returns:
            The preprocessed data ready for model inference.
        """
        raise NotImplementedError("Subclasses must implement the preprocess method.")

    @abstractmethod
    def postprocess(self, prediction: Any) -> Any:
        """
        Postprocesses the raw model output into a human-readable format.

        This method takes the raw output from the model's inference and
        converts it into a meaningful result, such as an age value, a gender
        label, or an emotion category.

        Args:
            prediction (Any): The raw output from the model's inference.

        Returns:
            The final, processed attributes.
        """
        raise NotImplementedError("Subclasses must implement the postprocess method.")

    @abstractmethod
    def predict(self, image: np.ndarray, *args: Any) -> Any:
        """
        Performs end-to-end attribute prediction on a given image.

        This method orchestrates the full pipeline: it calls the preprocess,
        inference, and postprocess steps to return the final, user-friendly
        attribute prediction.

        Args:
            image (np.ndarray): The input image containing the face.
            *args: Additional data required for prediction, such as a bounding
                box or landmarks.

        Returns:
            The final predicted attributes.
        """
        raise NotImplementedError("Subclasses must implement the predict method.")

    def __call__(self, *args, **kwargs) -> Any:
        """
        Provides a convenient, callable shortcut for the `predict` method.
        """
        return self.predict(*args, **kwargs)
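To illustrate the contract this base class enforces, here is a hypothetical, deliberately trivial subclass. It is not part of the library; it only shows which methods a new attribute model must provide and how __call__ delegates to predict.

# Hypothetical toy subclass showing the required Attribute interface.
from typing import Any, Tuple

import numpy as np

from uniface.attribute.base import Attribute


class ConstantSmile(Attribute):
    """Toy predictor that always reports the same attribute."""

    def __init__(self) -> None:
        self._initialize_model()

    def _initialize_model(self) -> None:
        self.model = None  # a real subclass would load weights and create a session here

    def preprocess(self, image: np.ndarray, *args: Any) -> np.ndarray:
        return image.astype(np.float32) / 255.0  # placeholder preprocessing

    def postprocess(self, prediction: Any) -> Tuple[str, float]:
        return "Happy", 1.0  # placeholder result

    def predict(self, image: np.ndarray, *args: Any) -> Tuple[str, float]:
        _ = self.preprocess(image, *args)
        return self.postprocess(None)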
uniface/attribute/emotion.py
@@ -5,218 +5,166 @@ (the class was rewritten on top of the new Attribute base; the updated contents follow)

import cv2
import torch
import numpy as np

from typing import Tuple, Union, List

from uniface.attribute.base import Attribute

from uniface.log import Logger
from uniface.constants import DDAMFNWeights
from uniface.face_utils import face_alignment
from uniface.model_store import verify_model_weights


__all__ = ["Emotion"]


class Emotion(Attribute):
    """
    Emotion recognition model using a TorchScript model.

    This class inherits from the base `Attribute` class and implements the
    functionality for predicting one of several emotion categories from a face
    image. It requires 5-point facial landmarks for alignment.
    """

    def __init__(
        self,
        model_weights: DDAMFNWeights = DDAMFNWeights.AFFECNET7,
        input_size: Tuple[int, int] = (112, 112),
    ) -> None:
        """
        Initializes the emotion recognition model.

        Args:
            model_weights (DDAMFNWeights): The enum for the model weights to load.
            input_size (Tuple[int, int]): The expected input size for the model.
        """
        Logger.info(f"Initializing Emotion with model={model_weights.name}")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_size = input_size
        self.model_path = verify_model_weights(model_weights)

        # Define emotion labels based on the selected model
        self.emotion_labels = ["Neutral", "Happy", "Sad", "Surprise", "Fear", "Disgust", "Angry"]
        if model_weights == DDAMFNWeights.AFFECNET8:
            self.emotion_labels.append("Contempt")

        self._initialize_model()

    def _initialize_model(self) -> None:
        """
        Loads and initializes the TorchScript model for inference.
        """
        try:
            self.model = torch.jit.load(self.model_path, map_location=self.device)
            self.model.eval()

            # Warm-up with a dummy input for faster first inference
            dummy_input = torch.randn(1, 3, *self.input_size).to(self.device)
            with torch.no_grad():
                self.model(dummy_input)
            Logger.info(f"Successfully initialized Emotion model on {self.device}")
        except Exception as e:
            Logger.error(f"Failed to load Emotion model from '{self.model_path}'", exc_info=True)
            raise RuntimeError(f"Failed to initialize Emotion model: {e}")

    def preprocess(self, image: np.ndarray, landmark: Union[List, np.ndarray]) -> torch.Tensor:
        """
        Aligns the face using landmarks and preprocesses it into a tensor.

        Args:
            image (np.ndarray): The full input image in BGR format.
            landmark (Union[List, np.ndarray]): The 5-point facial landmarks.

        Returns:
            torch.Tensor: The preprocessed image tensor ready for inference.
        """
        landmark = np.asarray(landmark)
        aligned_image, _ = face_alignment(image, landmark)

        # Convert BGR to RGB, resize, normalize, and convert to a CHW tensor
        rgb_image = cv2.cvtColor(aligned_image, cv2.COLOR_BGR2RGB)
        resized_image = cv2.resize(rgb_image, self.input_size).astype(np.float32) / 255.0
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        normalized_image = (resized_image - mean) / std
        transposed_image = normalized_image.transpose((2, 0, 1))

        return torch.from_numpy(transposed_image).unsqueeze(0).to(self.device)

    def postprocess(self, prediction: torch.Tensor) -> Tuple[str, float]:
        """
        Processes the raw model output to get the emotion label and confidence score.
        """
        probabilities = torch.nn.functional.softmax(prediction, dim=1).squeeze().cpu().numpy()
        pred_index = np.argmax(probabilities)
        emotion_label = self.emotion_labels[pred_index]
        confidence = float(probabilities[pred_index])
        return emotion_label, confidence

    def predict(self, image: np.ndarray, landmark: Union[List, np.ndarray]) -> Tuple[str, float]:
        """
        Predicts the emotion from a single face specified by its landmarks.
        """
        input_tensor = self.preprocess(image, landmark)
        with torch.no_grad():
            output = self.model(input_tensor)

        # Handle case where model returns a tuple
        if isinstance(output, tuple):
            output = output[0]

        return self.postprocess(output)


# TODO: below is only for testing, remove it later
if __name__ == "__main__":
    from uniface.detection import create_detector
    from uniface.constants import RetinaFaceWeights

    print("Initializing models for live inference...")

    # 1. Initialize the face detector
    # Using a smaller model for faster real-time performance
    detector = create_detector(model_name=RetinaFaceWeights.MNET_V2)

    # 2. Initialize the attribute predictor
    emotion_predictor = Emotion()

    # 3. Start webcam capture
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open webcam.")
        exit()

    print("Starting webcam feed. Press 'q' to quit.")
    while True:
        ret, frame = cap.read()
        if not ret:
            print("Error: Failed to capture frame.")
            break

        # Detect faces in the current frame.
        # This method returns a list of dictionaries for each detected face.
        detections = detector.detect(frame)

        # For each detected face, predict the emotion
        for detection in detections:
            box = detection['bbox']
            landmark = detection['landmarks']
            x1, y1, x2, y2 = map(int, box)

            # Predict attributes using the landmark
            emotion, confidence = emotion_predictor.predict(frame, landmark)

            # Prepare text and draw on the frame
            label = f"{emotion} ({confidence:.2f})"
            cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
            cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 2)

        # Display the resulting frame
        cv2.imshow("Emotion Inference (Press 'q' to quit)", frame)

        # Break the loop if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # Release resources
    cap.release()
    cv2.destroyAllWindows()
    print("Inference stopped.")
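A minimal still-image sketch combining detection and emotion prediction: the image path is hypothetical, and the 'landmarks' key follows the detector output format used in the test block above.

# Still-image sketch (hypothetical image path).
import cv2

from uniface.detection import create_detector
from uniface.constants import RetinaFaceWeights
from uniface.attribute.emotion import Emotion

image = cv2.imread("group_photo.jpg")
detector = create_detector(model_name=RetinaFaceWeights.MNET_V2)
emotion_predictor = Emotion()  # DDAMFNWeights.AFFECNET7 by default

for detection in detector.detect(image):
    emotion, confidence = emotion_predictor(image, detection['landmarks'])
    print(f"{emotion} ({confidence:.2f})")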
Landmark factory module:

@@ -6,7 +6,7 @@ from .models import Landmark106
 from .base import BaseLandmarker


-def create_landmarker(method: str = '2d106', **kwargs) -> BaseLandmarker:
+def create_landmarker(method: str = '2d106det', **kwargs) -> BaseLandmarker:
     """
     Factory function to create facial landmark predictors.

@@ -18,10 +18,10 @@ def create_landmarker(method: str = '2d106', **kwargs) -> BaseLandmarker:
         Initialized landmarker instance.
     """
     method = method.lower()
-    if method == 'insightface_106':
+    if method == '2d106det':
         return Landmark106(**kwargs)
     else:
-        available = ['insightface_106']
+        available = ['2d106det']
         raise ValueError(f"Unsupported method: '{method}'. Available: {available}")
@@ -168,7 +168,7 @@ if __name__ == "__main__":

     # 1. Create the detector and landmarker using the new API
     face_detector = create_detector('retinaface')
-    landmarker = create_landmarker()  # Uses the default '106' method
+    landmarker = create_landmarker()  # Uses the default '2d106det' method

     cap = cv2.VideoCapture(0)
     if not cap.isOpened():
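A small sketch of the renamed factory: the import path is an assumption (the diff does not show the module's file name), and the old method name is included only to show the error it now raises.

# Sketch of the renamed landmarker factory (import path assumed).
from uniface.landmark import create_landmarker  # assumed module path

landmarker = create_landmarker()  # equivalent to create_landmarker('2d106det')

try:
    create_landmarker('insightface_106')  # old name, no longer accepted
except ValueError as err:
    print(err)  # Unsupported method: 'insightface_106'. Available: ['2d106det']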