feat: Add face emotion model

yakhyo
2025-04-11 13:56:50 +09:00
parent 282737e0e9
commit 6a7ba6fc0a
5 changed files with 253 additions and 5 deletions

View File

@@ -22,7 +22,7 @@ def extract_reference_embedding(detector, recognizer, image_path):
return embedding
def run_video(detector, recognizer, ref_embedding, threshold=0.45):
def run_video(detector, recognizer, ref_embedding, threshold=0.30):
cap = cv2.VideoCapture(0)
if not cap.isOpened():
raise RuntimeError("Webcam could not be opened.")

View File

@@ -0,0 +1,2 @@
from .age_gender import AgeGender
from .emotion import Emotion

View File

@@ -3,22 +3,32 @@ import numpy as np
import onnxruntime
from typing import Tuple
from uniface.log import Logger
from uniface.face_utils import bbox_center_alignment
from uniface.model_store import verify_model_weights
from uniface.constants import AgeGenderWeights
__all__ = ["Attribute"]
__all__ = ["AgeGender"]
class Attribute:
class AgeGender:
"""
Age and Gender Prediction Model.
"""
def __init__(self, model_path: str) -> None:
def __init__(self, model_name: AgeGenderWeights = AgeGenderWeights.DEFAULT, input_size: Tuple[int, int] = (112, 112)) -> None:
"""
Initializes the Attribute model for inference.
Args:
model_path (str): Path to the ONNX file.
"""
Logger.info(
f"Initializing AgeGender with model={model_name}, input_size={input_size}"
)
self.model_path = verify_model_weights(model_name)
self.input_std = 1.0
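The remainder of the class is outside this hunk. As a rough sketch, the verified ONNX weights would typically be loaded with onnxruntime along these lines (illustrative, not part of the commit):

import onnxruntime

def load_session(model_path: str) -> onnxruntime.InferenceSession:
    # CPU provider keeps the sketch dependency-free; a CUDA build of
    # onnxruntime can use CUDAExecutionProvider instead.
    return onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])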

View File

@@ -0,0 +1,204 @@
# Copyright 2025 Yakhyokhuja Valikhujaev
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo
import cv2
import numpy as np
import torch
from PIL import Image
from typing import Tuple, Union
from uniface.log import Logger
from uniface import RetinaFace
from uniface.model_store import verify_model_weights
from uniface.constants import RetinaFaceWeights, DDAMFNWeights
class Emotion:
"""
Emotion recognition using a TorchScript model.
Args:
model_name (DDAMFNWeights): Pretrained model enum. Defaults to AFFECNET7.
Attributes:
emotions (List[str]): Emotion label list.
device (torch.device): Inference device (CPU or CUDA).
model (torch.jit.ScriptModule): Loaded TorchScript model.
Raises:
ValueError: If model weights are invalid or not found.
RuntimeError: If model loading fails.
"""
def __init__(self, model_name: DDAMFNWeights = DDAMFNWeights.AFFECNET7, input_size: Tuple[int, int] = (112, 112)) -> None:
"""
Initialize the emotion detector with a TorchScript model
"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.emotions = [
"Neutral", "Happy", "Sad", "Surprise", "Fear", "Disgust", "Angry"
]
if model_name == DDAMFNWeights.AFFECNET8:
self.emotions.append("Contempt")
self.input_size = input_size
Logger.info(
f"Initialized Emotion class with model={model_name.name}, "
f"device={'cuda' if torch.cuda.is_available() else 'cpu'}, "
f"num_classes={len(self.emotions)}, input_size={self.input_size}"
)
# Get path to model weights
self._model_path = verify_model_weights(model_name)
Logger.info(f"Verified model weights located at: {self._model_path}")
# Initialize model
self._initialize_model(model_path=self._model_path)
def _initialize_model(self, model_path: str) -> None:
"""
Initializes a TorchScript model for emotion inference.
Args:
model_path (str): Path to the TorchScript (.pt) model.
"""
try:
self.model = torch.jit.load(model_path, map_location=self.device)
self.model.eval()
Logger.info(f"TorchScript model successfully loaded from: {model_path}")
# Warm-up
dummy = torch.randn(1, 3, 112, 112).to(self.device)
with torch.no_grad():
_ = self.model(dummy)
Logger.info("Emotion model warmed up with dummy input.")
except Exception as e:
Logger.error(f"Failed to load TorchScript model from {model_path}: {e}")
raise
def preprocess(self, image: np.ndarray) -> torch.Tensor:
"""
Resize, normalize and convert image to tensor manually without torchvision.
Args:
image (np.ndarray): RGB image (H, W, 3)
Returns:
torch.Tensor: Preprocessed image tensor of shape (1, 3, 112, 112)
"""
# Resize to (112, 112)
image = cv2.resize(image, self.input_size).astype(np.float32) / 255.0
# Normalize with mean and std
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
image_normalized = (image - mean) / std
# HWC to CHW
image_transposed = image_normalized.transpose((2, 0, 1))
# Convert to torch tensor and add batch dimension
tensor = torch.from_numpy(image_transposed).unsqueeze(0).to(self.device)
return tensor
def predict(self, image: np.ndarray) -> Tuple[Union[str, None], Union[float, None]]:
"""
Predict the emotion from an RGB face image.
Args:
image (np.ndarray): Input face image in RGB format.
Returns:
Tuple[Union[str, None], Union[float, None]]: (Predicted emotion label, Confidence score), or (None, None) if inference fails.
Raises:
ValueError: If the input is not a NumPy RGB image of shape (H, W, 3).
"""
if not isinstance(image, np.ndarray):
Logger.error("Input must be a NumPy ndarray.")
raise ValueError("Input must be a NumPy ndarray (RGB image).")
if image.ndim != 3 or image.shape[2] != 3:
Logger.error(f"Invalid image shape: {image.shape}. Expected HxWx3 RGB image.")
raise ValueError("Input image must be in RGB format with shape (H, W, 3).")
try:
tensor = self.preprocess(image)
with torch.no_grad():
output = self.model(tensor)
if isinstance(output, tuple):
output = output[0]
probs = torch.nn.functional.softmax(output, dim=1).squeeze(0).cpu().numpy()
pred_idx = int(np.argmax(probs))
confidence = round(float(probs[pred_idx]), 2)
return self.emotions[pred_idx], confidence
except Exception as e:
Logger.error(f"Emotion inference failed: {e}")
return None, None
# TODO: For testing purposes only, remove later
def main():
face_detector = RetinaFace(
model_name=RetinaFaceWeights.MNET_V2,
conf_thresh=0.5,
pre_nms_topk=5000,
nms_thresh=0.4,
post_nms_topk=750,
dynamic_size=False,
input_size=(640, 640)
)
emotion_detector = Emotion()
cap = cv2.VideoCapture(0)
if not cap.isOpened():
print("Webcam not available.")
return
print("Press 'q' to quit.")
while True:
ret, frame = cap.read()
if not ret:
print("Frame capture failed.")
break
boxes, _ = face_detector.detect(frame)
for box in boxes:
x1, y1, x2, y2 = box[:4].astype(int)  # discard the detection score
face_crop = frame[y1:y2, x1:x2]
if face_crop.size == 0:
continue
face_rgb = cv2.cvtColor(face_crop, cv2.COLOR_BGR2RGB)
emotion, confidence = emotion_detector.predict(face_rgb)
if emotion is None:
continue
txt = f"{emotion} ({confidence:.2f})"
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
cv2.putText(frame, txt, (x1, y1 - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
cv2.imshow("Face + Emotion Detection", frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
if __name__ == "__main__":
main()
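For reference, a minimal single-image usage sketch of the new class (the import path and file name are assumptions; predict expects an RGB crop as documented above):

import cv2
from uniface.attribute import Emotion  # assumed import path; adjust to the actual package layout

detector = Emotion()  # defaults to DDAMFNWeights.AFFECNET7, input_size=(112, 112)

face_bgr = cv2.imread("face.jpg")  # illustrative path to an already-cropped face
face_rgb = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2RGB)  # predict() expects RGB (H, W, 3)

label, confidence = detector.predict(face_rgb)
if label is not None:
    print(f"{label}: {confidence:.2f}")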

View File

@@ -48,11 +48,28 @@ class RetinaFaceWeights(str, Enum):
class SCRFDWeights(str, Enum):
"""
Trained on WIDER FACE dataset.
https://github.com/deepinsight/insightface/tree/master/detection/scrfd
https://github.com/deepinsight/insightface
"""
SCRFD_10G_KPS = "scrfd_10g"
SCRFD_500M_KPS = "scrfd_500m"
class DDAMFNWeights(str, Enum):
"""
Trained on AffectNet dataset.
https://github.com/SainingZhang/DDAMFN/tree/main/DDAMFN
"""
AFFECNET7 = "affecnet7"
AFFECNET8 = "affecnet8"
class AgeGenderWeights(str, Enum):
"""
Trained on CelebA dataset.
https://github.com/deepinsight/insightface
"""
DEFAULT = "age_gender"
# fmt: on
@@ -84,6 +101,14 @@ MODEL_URLS: Dict[Enum, str] = {
# SCRFD
SCRFDWeights.SCRFD_10G_KPS: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/scrfd_10g_kps.onnx',
SCRFDWeights.SCRFD_500M_KPS: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/scrfd_500m_kps.onnx',
# DDAMFN
DDAMFNWeights.AFFECNET7: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/affecnet7.script',
DDAMFNWeights.AFFECNET8: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/affecnet8.script',
# AgeGender
AgeGenderWeights.DEFAULT: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/genderage.onnx',
}
MODEL_SHA256: Dict[Enum, str] = {
@@ -113,6 +138,13 @@ MODEL_SHA256: Dict[Enum, str] = {
# SCRFD
SCRFDWeights.SCRFD_10G_KPS: '5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91',
SCRFDWeights.SCRFD_500M_KPS: '5e4447f50245bbd7966bd6c0fa52938c61474a04ec7def48753668a9d8b4ea3a',
# DDAMFN
DDAMFNWeights.AFFECNET7: '10535bf8b6afe8e9d6ae26cea6c3add9a93036e9addb6adebfd4a972171d015d',
DDAMFNWeights.AFFECNET8: '8c66963bc71db42796a14dfcbfcd181b268b65a3fc16e87147d6a3a3d7e0f487',
# AgeGender
AgeGenderWeights.DEFAULT: '4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb',
}
CHUNK_SIZE = 8192
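For reference, a minimal sketch of how the checksum table is typically consumed when verifying downloaded weights (illustrative; the actual logic lives in uniface.model_store.verify_model_weights):

import hashlib

def sha256_of(path: str, chunk_size: int = CHUNK_SIZE) -> str:
    # Stream the file in CHUNK_SIZE blocks so large weight files are
    # hashed without loading them fully into memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# e.g. compare sha256_of(local_path) against MODEL_SHA256[DDAMFNWeights.AFFECNET7]
# after downloading from MODEL_URLS[DDAMFNWeights.AFFECNET7].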