feat: Add face emotion model
@@ -22,7 +22,7 @@ def extract_reference_embedding(detector, recognizer, image_path):
    return embedding


-def run_video(detector, recognizer, ref_embedding, threshold=0.45):
+def run_video(detector, recognizer, ref_embedding, threshold=0.30):
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise RuntimeError("Webcam could not be opened.")
@@ -0,0 +1,2 @@
from .age_gender import AgeGender
from .emotion import Emotion
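Note: the two-line hunk above is presumably uniface/attribute/__init__.py (the package path is implied by uniface/attribute/emotion.py below). Under that assumption, a minimal usage sketch of the new subpackage:

# Hypothetical usage sketch, not part of this commit.
from uniface.attribute import AgeGender, Emotion

emotion_model = Emotion()       # defaults to DDAMFNWeights.AFFECNET7, see emotion.py below
age_gender_model = AgeGender()  # enum-based constructor, see age_gender.py below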
@@ -3,22 +3,32 @@ import numpy as np
import onnxruntime
from typing import Tuple

from uniface.log import Logger
from uniface.face_utils import bbox_center_alignment
from uniface.model_store import verify_model_weights
from uniface.constants import AgeGenderWeights

-__all__ = ["Attribute"]
+__all__ = ["AgeGender"]


-class Attribute:
+class AgeGender:
    """
    Age and Gender Prediction Model.
    """
-    def __init__(self, model_path: str) -> None:
+    def __init__(self, model_name: AgeGenderWeights = AgeGenderWeights.DEFAULT, input_size: Tuple[int, int] = (112, 112)) -> None:
        """
        Initializes the AgeGender model for inference.

        Args:
            model_name (AgeGenderWeights): Pretrained model enum. Defaults to AgeGenderWeights.DEFAULT.
            input_size (Tuple[int, int]): Model input size. Defaults to (112, 112).
        """

        Logger.info(
            f"Initializing AgeGender with model={model_name}, input_size={input_size}"
        )

        self.model_path = model_path

        self.input_std = 1.0
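Note: this hunk changes AgeGender from a raw ONNX path to a weights enum. A minimal sketch of the new constructor call, assuming the remainder of __init__ resolves the enum to a local file via verify_model_weights (imported above), as the Emotion class below does:

# Sketch only: enum-based construction introduced by this hunk.
from uniface.attribute import AgeGender
from uniface.constants import AgeGenderWeights

age_gender = AgeGender(
    model_name=AgeGenderWeights.DEFAULT,  # weights are downloaded/verified by the library
    input_size=(112, 112),
)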
204  uniface/attribute/emotion.py  Normal file
@@ -0,0 +1,204 @@
# Copyright 2025 Yakhyokhuja Valikhujaev
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo

import cv2
import numpy as np
import torch
from PIL import Image

from typing import Tuple, Union

from uniface.log import Logger
from uniface import RetinaFace
from uniface.model_store import verify_model_weights
from uniface.constants import RetinaFaceWeights, DDAMFNWeights


class Emotion:
    """
    Emotion recognition using a TorchScript model.

    Args:
        model_name (DDAMFNWeights): Pretrained model enum. Defaults to AFFECNET7.

    Attributes:
        emotions (List[str]): Emotion label list.
        device (torch.device): Inference device (CPU or CUDA).
        model (torch.jit.ScriptModule): Loaded TorchScript model.

    Raises:
        ValueError: If model weights are invalid or not found.
        RuntimeError: If model loading fails.
    """

    def __init__(self, model_name: DDAMFNWeights = DDAMFNWeights.AFFECNET7, input_size: Tuple[int, int] = (112, 112)) -> None:
        """
        Initialize the emotion detector with a TorchScript model
        """

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.emotions = [
            "Neutral", "Happy", "Sad", "Surprise", "Fear", "Disgust", "Angry"
        ]
        if model_name == DDAMFNWeights.AFFECNET8:
            self.emotions.append("Contempt")

        self.input_size = input_size

        Logger.info(
            f"Initialized Emotion class with model={model_name.name}, "
            f"device={'cuda' if torch.cuda.is_available() else 'cpu'}, "
            f"num_classes={len(self.emotions)}, input_size={self.input_size}"
        )

        # Get path to model weights
        self._model_path = verify_model_weights(model_name)
        Logger.info(f"Verified model weights located at: {self._model_path}")

        # Initialize model
        self._initialize_model(model_path=self._model_path)

    def _initialize_model(self, model_path: str) -> None:
        """
        Initializes a TorchScript model for emotion inference.

        Args:
            model_path (str): Path to the TorchScript (.pt) model.
        """
        try:
            self.model = torch.jit.load(model_path, map_location=self.device)
            self.model.eval()
            Logger.info(f"TorchScript model successfully loaded from: {model_path}")

            # Warm-up
            dummy = torch.randn(1, 3, 112, 112).to(self.device)
            with torch.no_grad():
                _ = self.model(dummy)
            Logger.info("Emotion model warmed up with dummy input.")

        except Exception as e:
            Logger.error(f"Failed to load TorchScript model from {model_path}: {e}")
            raise

    def preprocess(self, image: np.ndarray) -> torch.Tensor:
        """
        Resize, normalize and convert image to tensor manually without torchvision.

        Args:
            image (np.ndarray): RGB image (H, W, 3)
        Returns:
            torch.Tensor: Preprocessed image tensor of shape (1, 3, 112, 112)
        """
        # Resize to (112, 112)
        image = cv2.resize(image, self.input_size).astype(np.float32) / 255.0

        # Normalize with mean and std
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        image_normalized = (image - mean) / std

        # HWC to CHW
        image_transposed = image_normalized.transpose((2, 0, 1))

        # Convert to torch tensor and add batch dimension
        tensor = torch.from_numpy(image_transposed).unsqueeze(0).to(self.device)

        return tensor
    def predict(self, image: np.ndarray) -> Tuple[Union[str, None], Union[float, None]]:
        """
        Predict the emotion from an RGB face image.

        Args:
            image (np.ndarray): Input face image in RGB format.

        Returns:
            Tuple[str, float]: (Predicted emotion label, confidence score);
            (None, None) is returned if inference fails internally.

        Raises:
            ValueError: If the input is not a NumPy RGB image of shape (H, W, 3).
        """
        if not isinstance(image, np.ndarray):
            Logger.error("Input must be a NumPy ndarray.")
            raise ValueError("Input must be a NumPy ndarray (RGB image).")

        if image.ndim != 3 or image.shape[2] != 3:
            Logger.error(f"Invalid image shape: {image.shape}. Expected HxWx3 RGB image.")
            raise ValueError("Input image must be in RGB format with shape (H, W, 3).")

        try:
            tensor = self.preprocess(image)

            with torch.no_grad():
                output = self.model(tensor)

            if isinstance(output, tuple):
                output = output[0]

            probs = torch.nn.functional.softmax(output, dim=1).squeeze(0).cpu().numpy()
            pred_idx = int(np.argmax(probs))
            confidence = round(float(probs[pred_idx]), 2)

            return self.emotions[pred_idx], confidence

        except Exception as e:
            Logger.error(f"Emotion inference failed: {e}")
            return None, None
# TODO: For testing purposes only, remove later
def main():
    face_detector = RetinaFace(
        model_name=RetinaFaceWeights.MNET_V2,
        conf_thresh=0.5,
        pre_nms_topk=5000,
        nms_thresh=0.4,
        post_nms_topk=750,
        dynamic_size=False,
        input_size=(640, 640)
    )
    emotion_detector = Emotion()

    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Webcam not available.")
        return

    print("Press 'q' to quit.")
    while True:
        ret, frame = cap.read()
        if not ret:
            print("Frame capture failed.")
            break

        boxes, _ = face_detector.detect(frame)

        for box in boxes:
            x1, y1, x2, y2, score = box.astype(int)
            face_crop = frame[y1:y2, x1:x2]

            if face_crop.size == 0:
                continue

            face_rgb = cv2.cvtColor(face_crop, cv2.COLOR_BGR2RGB)
            emotion, preds = emotion_detector.predict(face_rgb)

            # predict() returns (None, None) if inference fails; skip drawing in that case
            if emotion is None:
                continue

            txt = f"{emotion} ({preds:.2f})"
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, txt, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        cv2.imshow("Face + Emotion Detection", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
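Note: main() above is a webcam demo. For reference, a headless single-image variant built from the same public calls (RetinaFace.detect and Emotion.predict); the input path is a placeholder and the detector settings are copied from main():

# Sketch only: run the emotion model on a single image instead of a webcam stream.
import cv2

from uniface import RetinaFace
from uniface.attribute.emotion import Emotion
from uniface.constants import RetinaFaceWeights

detector = RetinaFace(
    model_name=RetinaFaceWeights.MNET_V2,
    conf_thresh=0.5,
    pre_nms_topk=5000,
    nms_thresh=0.4,
    post_nms_topk=750,
    dynamic_size=False,
    input_size=(640, 640)
)
emotion_model = Emotion()

frame = cv2.imread("face.jpg")  # hypothetical input file
if frame is None:
    raise FileNotFoundError("face.jpg not found")

boxes, _ = detector.detect(frame)
for box in boxes:
    x1, y1, x2, y2, score = box.astype(int)
    face_crop = frame[y1:y2, x1:x2]
    if face_crop.size == 0:
        continue
    # predict() expects RGB input and returns (None, None) on failure
    label, confidence = emotion_model.predict(cv2.cvtColor(face_crop, cv2.COLOR_BGR2RGB))
    if label is not None:
        print(f"({x1}, {y1}, {x2}, {y2}) -> {label} ({confidence:.2f})")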
@@ -48,11 +48,28 @@ class RetinaFaceWeights(str, Enum):
class SCRFDWeights(str, Enum):
    """
    Trained on WIDER FACE dataset.
    https://github.com/deepinsight/insightface/tree/master/detection/scrfd
    https://github.com/deepinsight/insightface
    """
    SCRFD_10G_KPS = "scrfd_10g"
    SCRFD_500M_KPS = "scrfd_500m"


class DDAMFNWeights(str, Enum):
    """
    Trained on AffectNet dataset.
    https://github.com/SainingZhang/DDAMFN/tree/main/DDAMFN
    """
    AFFECNET7 = "affecnet7"
    AFFECNET8 = "affecnet8"


class AgeGenderWeights(str, Enum):
    """
    Trained on CelebA dataset.
    https://github.com/deepinsight/insightface
    """
    DEFAULT = "age_gender"

# fmt: on
@@ -84,6 +101,14 @@ MODEL_URLS: Dict[Enum, str] = {
    # SCRFD
    SCRFDWeights.SCRFD_10G_KPS: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/scrfd_10g_kps.onnx',
    SCRFDWeights.SCRFD_500M_KPS: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/scrfd_500m_kps.onnx',

    # DDAMFN
    DDAMFNWeights.AFFECNET7: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/affecnet7.script',
    DDAMFNWeights.AFFECNET8: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/affecnet8.script',

    # AgeGender
    AgeGenderWeights.DEFAULT: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/genderage.onnx',
}

MODEL_SHA256: Dict[Enum, str] = {
@@ -113,6 +138,13 @@ MODEL_SHA256: Dict[Enum, str] = {
    # SCRFD
    SCRFDWeights.SCRFD_10G_KPS: '5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91',
    SCRFDWeights.SCRFD_500M_KPS: '5e4447f50245bbd7966bd6c0fa52938c61474a04ec7def48753668a9d8b4ea3a',

    # DDAMFN
    DDAMFNWeights.AFFECNET7: '10535bf8b6afe8e9d6ae26cea6c3add9a93036e9addb6adebfd4a972171d015d',
    DDAMFNWeights.AFFECNET8: '8c66963bc71db42796a14dfcbfcd181b268b65a3fc16e87147d6a3a3d7e0f487',

    # AgeGender
    AgeGenderWeights.DEFAULT: '4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb',
}

CHUNK_SIZE = 8192
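Note: the new DDAMFN and AgeGender entries follow the existing MODEL_URLS / MODEL_SHA256 pattern, and CHUNK_SIZE suggests files are hashed in 8 KiB chunks. A minimal sketch of checking a downloaded weight file against the registry with hashlib; the local path is hypothetical and the check mirrors what verify_model_weights presumably does:

# Sketch only: hash-check a downloaded weight file against the table above.
import hashlib

def sha256_of(path: str, chunk_size: int = 8192) -> str:
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Expected digest for DDAMFNWeights.AFFECNET7, copied from MODEL_SHA256 above.
expected = "10535bf8b6afe8e9d6ae26cea6c3add9a93036e9addb6adebfd4a972171d015d"
actual = sha256_of("affecnet7.script")  # hypothetical local download location
print("weights verified" if actual == expected else "checksum mismatch")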