feat: update age and gender inference code

yakhyo
2025-04-21 12:19:18 +09:00
parent eef4a0624a
commit 29964df259
3 changed files with 85 additions and 19 deletions


@@ -8,6 +8,9 @@ from uniface.face_utils import bbox_center_alignment
+import cv2
 from uniface.model_store import verify_model_weights
 from uniface.constants import AgeGenderWeights
+from uniface.detection import RetinaFace
+from uniface.constants import RetinaFaceWeights

 __all__ = ["AgeGender"]
@@ -15,26 +18,30 @@ class AgeGender:
     """
     Age and Gender Prediction Model.
     """

-    def __init__(self, model_name: AgeGenderWeights = AgeGenderWeights.DEFAULT, input_size:Tuple[int, int] = (112, 112)) -> None:
+    def __init__(self, model_name: AgeGenderWeights = AgeGenderWeights.DEFAULT, input_size: Tuple[int, int] = (112, 112)) -> None:
         """
         Initializes the Attribute model for inference.

         Args:
             model_name (AgeGenderWeights): Enum specifying which model weights to load.
             input_size (Tuple[int, int]): Model input size, default (112, 112).
         """
         Logger.info(
-            f"Initializing RetinaFace with model={model_name}, conf_thresh={conf_thresh}, nms_thresh={nms_thresh}, "
-            f"pre_nms_topk={pre_nms_topk}, post_nms_topk={post_nms_topk}, dynamic_size={dynamic_size}, "
+            f"Initializing AgeGender with model={model_name}, "
             f"input_size={input_size}"
         )

-        self.model_path = model_path
         self.input_size = input_size
         self.input_std = 1.0
         self.input_mean = 0.0

-        self._initialize_model(model_path=model_path)
+        # Get path to model weights
+        self._model_path = verify_model_weights(model_name)
+        Logger.info(f"Verified model weights located at: {self._model_path}")
+
+        # Initialize model
+        self._initialize_model(model_path=self._model_path)

     def _initialize_model(self, model_path: str):
         """Initialize the model from the given path.
@@ -102,9 +109,66 @@ class AgeGender:
         age = int(np.round(predictions[2] * 100))

         return gender, age

-    def get(self, image: np.ndarray, bbox: np.ndarray) -> Tuple[np.int64, int]:
+    def predict(self, image: np.ndarray, bbox: np.ndarray) -> Tuple[np.int64, int]:
         blob = self.preprocess(image, bbox)
         predictions = self.session.run(self.output_names, {self.input_names[0]: blob})[0][0]

         gender, age = self.postprocess(predictions)
         return gender, age

+
+# TODO: For testing purposes only, remove later
+def main():
+    face_detector = RetinaFace(
+        model_name=RetinaFaceWeights.MNET_V2,
+        conf_thresh=0.5,
+        pre_nms_topk=5000,
+        nms_thresh=0.4,
+        post_nms_topk=750,
+        dynamic_size=False,
+        input_size=(640, 640)
+    )
+    age_detector = AgeGender()
+
+    cap = cv2.VideoCapture(0)
+    if not cap.isOpened():
+        print("Webcam not available.")
+        return
+
+    print("Press 'q' to quit.")
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            print("Frame capture failed.")
+            break
+
+        boxes, landmarks = face_detector.detect(frame)
+        for box, landmark in zip(boxes, landmarks):
+            x1, y1, x2, y2, score = box.astype(int)
+            face_crop = frame[y1:y2, x1:x2]
+            if face_crop.size == 0:
+                continue
+
+            gender, age = age_detector.predict(frame, box[:4])
+            txt = f"{gender} ({age})"
+            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
+            cv2.putText(frame, txt, (x1, y1 - 10),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
+
+        cv2.imshow("Face + Age/Gender Detection", frame)
+        if cv2.waitKey(1) & 0xFF == ord('q'):
+            break
+
+    cap.release()
+    cv2.destroyAllWindows()
+
+
+if __name__ == "__main__":
+    main()
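For a non-interactive check of the renamed predict() API, a minimal sketch on a static image. It assumes AgeGender is exported from the package root and that RetinaFace's remaining constructor arguments have defaults; the image path is a placeholder:

import cv2
from uniface.detection import RetinaFace
from uniface.constants import RetinaFaceWeights
from uniface import AgeGender  # assumed export; the class is defined in the file above

image = cv2.imread("face.jpg")  # placeholder path; substitute a real image
detector = RetinaFace(model_name=RetinaFaceWeights.MNET_V2, conf_thresh=0.5)
age_gender = AgeGender()

boxes, _ = detector.detect(image)
for box in boxes:
    # predict() takes the full frame plus a 4-element bbox, mirroring main() above
    gender, age = age_gender.predict(image, box[:4])
    print(f"gender={gender}, age={age}")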


@@ -37,7 +37,7 @@ class Emotion:
"""
Initialize the emotion detector with a TorchScript model
"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.emotions = [
@@ -47,7 +47,9 @@ class Emotion:
             self.emotions.append("Contempt")

         self.input_size = input_size
+        self.input_std = [0.229, 0.224, 0.225]
+        self.input_mean = [0.485, 0.456, 0.406]

         Logger.info(
             f"Initialized Emotion class with model={model_name.name}, "
             f"device={'cuda' if torch.cuda.is_available() else 'cpu'}, "
@@ -92,14 +94,14 @@ class Emotion:
         Returns:
             torch.Tensor: Preprocessed image tensor of shape (1, 3, 112, 112)
         """
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # BGR -> RGB
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # BGR -> RGB

         # Resize to (112, 112)
         image = cv2.resize(image, self.input_size).astype(np.float32) / 255.0

         # Normalize with mean and std
-        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
-        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
+        mean = np.array(self.input_mean, dtype=np.float32)
+        std = np.array(self.input_std, dtype=np.float32)
         image_normalized = (image - mean) / std

         # HWC to CHW

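Read end to end, the preprocessing this hunk touches amounts to the standalone sketch below; the final CHW transpose and batch dimension are inferred from the documented (1, 3, 112, 112) output shape rather than shown in the diff:

import cv2
import numpy as np
import torch

def preprocess_sketch(image_bgr: np.ndarray, input_size=(112, 112)) -> torch.Tensor:
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)  # BGR -> RGB
    image = cv2.resize(image, input_size).astype(np.float32) / 255.0
    # ImageNet statistics, now read from instance attributes in the class above
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    image = (image - mean) / std
    # HWC -> CHW, then add a batch dimension -> (1, 3, 112, 112)
    return torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0)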

@@ -121,18 +121,18 @@ def bbox_center_alignment(image, center, output_size, scale, rotation):
     rot = float(rotation) * np.pi / 180.0

     # Scale the image
-    t1 = trans.SimilarityTransform(scale=scale)
+    t1 = SimilarityTransform(scale=scale)

     # Translate the center point to the origin (after scaling)
     cx = center[0] * scale
     cy = center[1] * scale
-    t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
+    t2 = SimilarityTransform(translation=(-1 * cx, -1 * cy))

     # Apply rotation around origin (center of face)
-    t3 = trans.SimilarityTransform(rotation=rot)
+    t3 = SimilarityTransform(rotation=rot)

     # Translate origin to center of output image
-    t4 = trans.SimilarityTransform(translation=(output_size / 2, output_size / 2))
+    t4 = SimilarityTransform(translation=(output_size / 2, output_size / 2))

     # Combine all transformations in order: scale → center shift → rotate → recentralize
     t = t1 + t2 + t3 + t4
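In skimage, adding SimilarityTransforms composes them left to right, i.e. (t1 + t2)(p) == t2(t1(p)), so the chain above scales first and recenters last. A self-contained sketch with made-up example values, applying the composed matrix with OpenCV; the warpAffine step is illustrative, not part of this function:

import cv2
import numpy as np
from skimage.transform import SimilarityTransform

center, scale, rotation, output_size = (120.0, 96.0), 0.8, 15.0, 112  # example values
rot = float(rotation) * np.pi / 180.0

t = (SimilarityTransform(scale=scale)
     + SimilarityTransform(translation=(-center[0] * scale, -center[1] * scale))
     + SimilarityTransform(rotation=rot)
     + SimilarityTransform(translation=(output_size / 2, output_size / 2)))

image = np.zeros((240, 240, 3), dtype=np.uint8)  # placeholder input
# The top two rows of the 3x3 homogeneous matrix form OpenCV's 2x3 affine
aligned = cv2.warpAffine(image, t.params[:2], (output_size, output_size))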