mirror of https://github.com/yakhyo/uniface.git
feat: Age and gender inference code updated
@@ -8,6 +8,9 @@ from uniface.face_utils import bbox_center_alignment
 from uniface.model_store import verify_model_weights
 from uniface.constants import AgeGenderWeights
 
+from uniface.detection import RetinaFace
+from uniface.constants import RetinaFaceWeights
+
 __all__ = ["AgeGender"]
 
 
@@ -15,26 +18,30 @@ class AgeGender:
     """
     Age and Gender Prediction Model.
     """
 
-    def __init__(self, model_name: AgeGenderWeights = AgeGenderWeights.DEFAULT, input_size:Tuple[int, int] = (112, 112)) -> None:
+    def __init__(self, model_name: AgeGenderWeights = AgeGenderWeights.DEFAULT, input_size: Tuple[int, int] = (112, 112)) -> None:
         """
         Initializes the Attribute model for inference.
 
         Args:
             model_path (str): Path to the ONNX file.
         """
 
         Logger.info(
-            f"Initializing RetinaFace with model={model_name}, conf_thresh={conf_thresh}, nms_thresh={nms_thresh}, "
-            f"pre_nms_topk={pre_nms_topk}, post_nms_topk={post_nms_topk}, dynamic_size={dynamic_size}, "
+            f"Initializing AgeGender with model={model_name}, "
+            f"input_size={input_size}"
         )
 
-        self.model_path = model_path
         self.input_size = input_size
         self.input_std = 1.0
         self.input_mean = 0.0
 
-        self._initialize_model(model_path=model_path)
+        # Get path to model weights
+        self._model_path = verify_model_weights(model_name)
+        Logger.info(f"Verified model weights located at: {self._model_path}")
+
+        # Initialize model
+        self._initialize_model(model_path=self._model_path)
 
     def _initialize_model(self, model_path: str):
         """Initialize the model from the given path.
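For reference, a minimal usage sketch of the reworked constructor: weights are now resolved by name through verify_model_weights instead of being passed as a raw model_path. RetinaFace and the weight enums are imported exactly as in this diff; the AgeGender import path, the image file name, and the assumption that the remaining RetinaFace arguments have defaults are not confirmed by the source.

import cv2

from uniface.detection import RetinaFace
from uniface.constants import RetinaFaceWeights, AgeGenderWeights
from uniface.age_gender import AgeGender  # assumption: actual module path may differ

# Detector for face boxes, attribute model for gender/age
detector = RetinaFace(model_name=RetinaFaceWeights.MNET_V2, conf_thresh=0.5)
age_gender = AgeGender(model_name=AgeGenderWeights.DEFAULT, input_size=(112, 112))

image = cv2.imread("face.jpg")  # placeholder image path
boxes, _ = detector.detect(image)
for box in boxes:
    # predict() takes the full frame plus the box, as in the demo added below
    gender, age = age_gender.predict(image, box[:4])
    print(f"gender={gender}, age={age}")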
@@ -102,9 +109,66 @@ class AgeGender:
         age = int(np.round(predictions[2]*100))
         return gender, age
 
-    def get(self, image: np.ndarray, bbox: np.ndarray) -> Tuple[np.int64, int]:
+    def predict(self, image: np.ndarray, bbox: np.ndarray) -> Tuple[np.int64, int]:
         blob = self.preprocess(image, bbox)
         predictions = self.session.run(self.output_names, {self.input_names[0]: blob})[0][0]
         gender, age = self.postprocess(predictions)
 
         return gender, age
+
+
+
+# TODO: For testing purposes only, remove later
+
+def main():
+
+    face_detector = RetinaFace(
+        model_name=RetinaFaceWeights.MNET_V2,
+        conf_thresh=0.5,
+        pre_nms_topk=5000,
+        nms_thresh=0.4,
+        post_nms_topk=750,
+        dynamic_size=False,
+        input_size=(640, 640)
+    )
+    age_detector = AgeGender()
+
+    cap = cv2.VideoCapture(0)
+    if not cap.isOpened():
+        print("Webcam not available.")
+        return
+
+    print("Press 'q' to quit.")
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            print("Frame capture failed.")
+            break
+
+        boxes, landmarks = face_detector.detect(frame)
+
+        for box, landmark in zip(boxes, landmarks):
+            x1, y1, x2, y2, score = box.astype(int)
+            face_crop = frame[y1:y2, x1:x2]
+
+            if face_crop.size == 0:
+                continue
+
+            gender, age = age_detector.predict(frame, box[:4])
+
+            txt = f"{gender} ({age:.2f})"
+            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
+            cv2.putText(frame, txt, (x1, y1 - 10),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
+
+        cv2.imshow("Face + Age & Gender Detection", frame)
+        if cv2.waitKey(1) & 0xFF == ord('q'):
+            break
+
+    cap.release()
+    cv2.destroyAllWindows()
+
+
+if __name__ == "__main__":
+    main()
+
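The webcam demo above needs a camera and a display. As a hedged variant under the same assumptions (same classes and imports as this module, placeholder file name), the loop below runs the identical detect-then-predict flow headlessly over a video file.

import cv2


def run_on_video(path: str = "sample.mp4") -> None:
    # Same models as main(); remaining RetinaFace arguments are assumed to have defaults
    face_detector = RetinaFace(model_name=RetinaFaceWeights.MNET_V2, conf_thresh=0.5)
    age_detector = AgeGender()

    cap = cv2.VideoCapture(path)
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        boxes, _ = face_detector.detect(frame)
        for box in boxes:
            gender, age = age_detector.predict(frame, box[:4])
            print(f"box={box[:4].astype(int)}, gender={gender}, age={age}")
    cap.release()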
@@ -37,7 +37,7 @@ class Emotion:
         """
         Initialize the emotion detector with a TorchScript model
         """
 
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         self.emotions = [
@@ -47,7 +47,9 @@ class Emotion:
             self.emotions.append("Contempt")
 
         self.input_size = input_size
+        self.input_std = [0.229, 0.224, 0.225]
+        self.input_mean = [0.485, 0.456, 0.406]
 
         Logger.info(
             f"Initialized Emotion class with model={model_name.name}, "
             f"device={'cuda' if torch.cuda.is_available() else 'cpu'}, "
@@ -92,14 +94,14 @@ class Emotion:
         Returns:
             torch.Tensor: Preprocessed image tensor of shape (1, 3, 112, 112)
         """
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # BGR -> RGB
 
         # Resize to (112, 112)
         image = cv2.resize(image, self.input_size).astype(np.float32) / 255.0
 
         # Normalize with mean and std
-        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
-        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
+        mean = np.array(self.input_mean, dtype=np.float32)
+        std = np.array(self.input_std, dtype=np.float32)
         image_normalized = (image - mean) / std
 
         # HWC to CHW
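The hunk is cut off right after the HWC to CHW comment. For context, a self-contained sketch of the preprocessing it describes (BGR to RGB, resize, scale to [0, 1], normalize with the mean/std now stored on the instance, channel-first layout, batch dimension); the final tensor conversion is an assumption about the lines not shown.

import cv2
import numpy as np
import torch


def preprocess_face(image: np.ndarray, input_size=(112, 112)) -> torch.Tensor:
    # BGR -> RGB, resize, scale pixels to [0, 1]
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, input_size).astype(np.float32) / 255.0

    # Normalize with the same ImageNet-style mean/std used in the diff
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    image = (image - mean) / std

    # HWC -> CHW plus a batch dimension: (1, 3, 112, 112)
    image = np.ascontiguousarray(np.transpose(image, (2, 0, 1)))[None, ...]
    return torch.from_numpy(image)  # assumption: the model expects a float32 tensor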
@@ -121,18 +121,18 @@ def bbox_center_alignment(image, center, output_size, scale, rotation):
     rot = float(rotation) * np.pi / 180.0
 
     # Scale the image
-    t1 = trans.SimilarityTransform(scale=scale)
+    t1 = SimilarityTransform(scale=scale)
 
     # Translate the center point to the origin (after scaling)
     cx = center[0] * scale
     cy = center[1] * scale
-    t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
+    t2 = SimilarityTransform(translation=(-1 * cx, -1 * cy))
 
     # Apply rotation around origin (center of face)
-    t3 = trans.SimilarityTransform(rotation=rot)
+    t3 = SimilarityTransform(rotation=rot)
 
     # Translate origin to center of output image
-    t4 = trans.SimilarityTransform(translation=(output_size / 2, output_size / 2))
+    t4 = SimilarityTransform(translation=(output_size / 2, output_size / 2))
 
     # Combine all transformations in order: scale → center shift → rotate → recentralize
     t = t1 + t2 + t3 + t4
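For readers unfamiliar with scikit-image transform composition: t1 + t2 applies t1 first and then t2, and the combined 3x3 homogeneous matrix is available as t.params. Below is a hedged sketch of how a chain like the one above is typically applied to an image; the warpAffine call and the function name are assumptions about code not shown in this hunk.

import cv2
import numpy as np
from skimage.transform import SimilarityTransform


def center_align(image: np.ndarray, center, output_size: int, scale: float, rotation: float) -> np.ndarray:
    # Same chain as the diff: scale -> shift center to origin -> rotate -> recenter
    rot = float(rotation) * np.pi / 180.0
    t = (SimilarityTransform(scale=scale)
         + SimilarityTransform(translation=(-center[0] * scale, -center[1] * scale))
         + SimilarityTransform(rotation=rot)
         + SimilarityTransform(translation=(output_size / 2, output_size / 2)))

    # warpAffine expects the top 2x3 block of the homogeneous matrix
    matrix = t.params[:2]
    return cv2.warpAffine(image, matrix, (output_size, output_size))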