chore: Code style formatting changes

This commit is contained in:
yakhyo
2025-11-26 00:05:24 +09:00
parent 0771a7959a
commit 84cda5f56c
12 changed files with 172 additions and 77 deletions

View File

@@ -19,7 +19,7 @@ dependencies = [
requires-python = ">=3.10" requires-python = ">=3.10"
[project.optional-dependencies] [project.optional-dependencies]
dev = ["pytest>=7.0.0"] dev = ["pytest>=7.0.0", "ruff>=0.4.0"]
gpu = ["onnxruntime-gpu>=1.16.0"] gpu = ["onnxruntime-gpu>=1.16.0"]
[project.urls] [project.urls]
@@ -35,3 +35,13 @@ packages = { find = {} }
[tool.setuptools.package-data] [tool.setuptools.package-data]
"uniface" = ["*.txt", "*.md"] "uniface" = ["*.txt", "*.md"]
[tool.ruff]
line-length = 120
target-version = "py310"
[tool.ruff.lint]
select = ["E", "F", "I", "W"]
[tool.ruff.lint.isort]
known-first-party = ["uniface"]

View File

@@ -33,7 +33,15 @@ def process_image(detector, image_path: Path, output_path: Path, threshold: floa
landmarks = [f["landmarks"] for f in faces] landmarks = [f["landmarks"] for f in faces]
draw_detections(image, bboxes, scores, landmarks, vis_threshold=threshold) draw_detections(image, bboxes, scores, landmarks, vis_threshold=threshold)
cv2.putText(image, f"Faces: {len(faces)}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) cv2.putText(
image,
f"Faces: {len(faces)}",
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX,
1,
(0, 255, 0),
2,
)
cv2.imwrite(str(output_path), image) cv2.imwrite(str(output_path), image)
return len(faces) return len(faces)

View File

@@ -21,7 +21,13 @@ def draw_age_gender_label(image, bbox, gender: str, age: int):
cv2.putText(image, text, (x1 + 5, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2) cv2.putText(image, text, (x1 + 5, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2)
def process_image(detector, age_gender, image_path: str, save_dir: str = "outputs", threshold: float = 0.6): def process_image(
detector,
age_gender,
image_path: str,
save_dir: str = "outputs",
threshold: float = 0.6,
):
image = cv2.imread(image_path) image = cv2.imread(image_path)
if image is None: if image is None:
print(f"Error: Failed to load image from '{image_path}'") print(f"Error: Failed to load image from '{image_path}'")
@@ -75,7 +81,15 @@ def run_webcam(detector, age_gender, threshold: float = 0.6):
gender, age = age_gender.predict(frame, face["bbox"]) # predict per face gender, age = age_gender.predict(frame, face["bbox"]) # predict per face
draw_age_gender_label(frame, face["bbox"], gender, age) draw_age_gender_label(frame, face["bbox"], gender, age)
cv2.putText(frame, f"Faces: {len(faces)}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) cv2.putText(
frame,
f"Faces: {len(faces)}",
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX,
1,
(0, 255, 0),
2,
)
cv2.imshow("Age & Gender Detection", frame) cv2.imshow("Age & Gender Detection", frame)
if cv2.waitKey(1) & 0xFF == ord("q"): if cv2.waitKey(1) & 0xFF == ord("q"):

View File

@@ -53,7 +53,15 @@ def run_webcam(detector, threshold: float = 0.6):
landmarks = [f["landmarks"] for f in faces] landmarks = [f["landmarks"] for f in faces]
draw_detections(frame, bboxes, scores, landmarks, vis_threshold=threshold) draw_detections(frame, bboxes, scores, landmarks, vis_threshold=threshold)
cv2.putText(frame, f"Faces: {len(faces)}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) cv2.putText(
frame,
f"Faces: {len(faces)}",
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX,
1,
(0, 255, 0),
2,
)
cv2.imshow("Face Detection", frame) cv2.imshow("Face Detection", frame)
if cv2.waitKey(1) & 0xFF == ord("q"): if cv2.waitKey(1) & 0xFF == ord("q"):

View File

@@ -76,7 +76,12 @@ def main():
parser.add_argument("--image", type=str, required=True, help="Reference face image") parser.add_argument("--image", type=str, required=True, help="Reference face image")
parser.add_argument("--threshold", type=float, default=0.4, help="Match threshold") parser.add_argument("--threshold", type=float, default=0.4, help="Match threshold")
parser.add_argument("--detector", type=str, default="scrfd", choices=["retinaface", "scrfd"]) parser.add_argument("--detector", type=str, default="scrfd", choices=["retinaface", "scrfd"])
parser.add_argument("--recognizer", type=str, default="arcface", choices=["arcface", "mobileface", "sphereface"]) parser.add_argument(
"--recognizer",
type=str,
default="arcface",
choices=["arcface", "mobileface", "sphereface"],
)
args = parser.parse_args() args = parser.parse_args()
detector = RetinaFace() if args.detector == "retinaface" else SCRFD() detector = RetinaFace() if args.detector == "retinaface" else SCRFD()

View File

@@ -34,7 +34,15 @@ def process_image(detector, landmarker, image_path: str, save_dir: str = "output
for x, y in landmarks.astype(int): for x, y in landmarks.astype(int):
cv2.circle(image, (x, y), 1, (0, 255, 0), -1) cv2.circle(image, (x, y), 1, (0, 255, 0), -1)
cv2.putText(image, f"Face {i + 1}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) cv2.putText(
image,
f"Face {i + 1}",
(x1, y1 - 10),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(0, 255, 0),
2,
)
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
output_path = os.path.join(save_dir, f"{Path(image_path).stem}_landmarks.jpg") output_path = os.path.join(save_dir, f"{Path(image_path).stem}_landmarks.jpg")
@@ -67,7 +75,15 @@ def run_webcam(detector, landmarker):
for x, y in landmarks.astype(int): for x, y in landmarks.astype(int):
cv2.circle(frame, (x, y), 1, (0, 255, 0), -1) cv2.circle(frame, (x, y), 1, (0, 255, 0), -1)
cv2.putText(frame, f"Faces: {len(faces)}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) cv2.putText(
frame,
f"Faces: {len(faces)}",
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX,
1,
(0, 255, 0),
2,
)
cv2.imshow("106-Point Landmarks", frame) cv2.imshow("106-Point Landmarks", frame)
if cv2.waitKey(1) & 0xFF == ord("q"): if cv2.waitKey(1) & 0xFF == ord("q"):

View File

@@ -79,7 +79,12 @@ def main():
parser.add_argument("--image2", type=str, help="Second image for comparison") parser.add_argument("--image2", type=str, help="Second image for comparison")
parser.add_argument("--threshold", type=float, default=0.35, help="Similarity threshold") parser.add_argument("--threshold", type=float, default=0.35, help="Similarity threshold")
parser.add_argument("--detector", type=str, default="retinaface", choices=["retinaface", "scrfd"]) parser.add_argument("--detector", type=str, default="retinaface", choices=["retinaface", "scrfd"])
parser.add_argument("--recognizer", type=str, default="arcface", choices=["arcface", "mobileface", "sphereface"]) parser.add_argument(
"--recognizer",
type=str,
default="arcface",
choices=["arcface", "mobileface", "sphereface"],
)
args = parser.parse_args() args = parser.parse_args()
detector = RetinaFace() if args.detector == "retinaface" else SCRFD() detector = RetinaFace() if args.detector == "retinaface" else SCRFD()

View File

@@ -11,7 +11,13 @@ from uniface import SCRFD, RetinaFace
from uniface.visualization import draw_detections from uniface.visualization import draw_detections
def process_video(detector, input_path: str, output_path: str, threshold: float = 0.6, show_preview: bool = False): def process_video(
detector,
input_path: str,
output_path: str,
threshold: float = 0.6,
show_preview: bool = False,
):
cap = cv2.VideoCapture(input_path) cap = cv2.VideoCapture(input_path)
if not cap.isOpened(): if not cap.isOpened():
print(f"Error: Cannot open video file '{input_path}'") print(f"Error: Cannot open video file '{input_path}'")
@@ -51,7 +57,15 @@ def process_video(detector, input_path: str, output_path: str, threshold: float
landmarks = [f["landmarks"] for f in faces] landmarks = [f["landmarks"] for f in faces]
draw_detections(frame, bboxes, scores, landmarks, vis_threshold=threshold) draw_detections(frame, bboxes, scores, landmarks, vis_threshold=threshold)
cv2.putText(frame, f"Faces: {len(faces)}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) cv2.putText(
frame,
f"Faces: {len(faces)}",
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX,
1,
(0, 255, 0),
2,
)
out.write(frame) out.write(frame)
if show_preview: if show_preview:

View File

@@ -31,7 +31,7 @@ def test_prediction_output_format(age_gender_model, mock_image, mock_bbox):
def test_gender_values(age_gender_model, mock_image, mock_bbox): def test_gender_values(age_gender_model, mock_image, mock_bbox):
gender, age = age_gender_model.predict(mock_image, mock_bbox) gender, age = age_gender_model.predict(mock_image, mock_bbox)
assert gender in ['Male', 'Female'], f"Gender should be 'Male' or 'Female', got '{gender}'" assert gender in ["Male", "Female"], f"Gender should be 'Male' or 'Female', got '{gender}'"
def test_age_range(age_gender_model, mock_image, mock_bbox): def test_age_range(age_gender_model, mock_image, mock_bbox):
@@ -48,7 +48,7 @@ def test_different_bbox_sizes(age_gender_model, mock_image):
for bbox in test_bboxes: for bbox in test_bboxes:
gender, age = age_gender_model.predict(mock_image, bbox) gender, age = age_gender_model.predict(mock_image, bbox)
assert gender in ['Male', 'Female'], f"Failed for bbox {bbox}" assert gender in ["Male", "Female"], f"Failed for bbox {bbox}"
assert 0 <= age <= 120, f"Age out of range for bbox {bbox}" assert 0 <= age <= 120, f"Age out of range for bbox {bbox}"
@@ -58,7 +58,7 @@ def test_different_image_sizes(age_gender_model, mock_bbox):
for size in test_sizes: for size in test_sizes:
mock_image = np.random.randint(0, 255, size, dtype=np.uint8) mock_image = np.random.randint(0, 255, size, dtype=np.uint8)
gender, age = age_gender_model.predict(mock_image, mock_bbox) gender, age = age_gender_model.predict(mock_image, mock_bbox)
assert gender in ['Male', 'Female'], f"Failed for image size {size}" assert gender in ["Male", "Female"], f"Failed for image size {size}"
assert 0 <= age <= 120, f"Age out of range for image size {size}" assert 0 <= age <= 120, f"Age out of range for image size {size}"
@@ -73,14 +73,14 @@ def test_consistency(age_gender_model, mock_image, mock_bbox):
def test_bbox_list_format(age_gender_model, mock_image): def test_bbox_list_format(age_gender_model, mock_image):
bbox_list = [100, 100, 300, 300] bbox_list = [100, 100, 300, 300]
gender, age = age_gender_model.predict(mock_image, bbox_list) gender, age = age_gender_model.predict(mock_image, bbox_list)
assert gender in ['Male', 'Female'], "Should work with bbox as list" assert gender in ["Male", "Female"], "Should work with bbox as list"
assert 0 <= age <= 120, "Age should be in valid range" assert 0 <= age <= 120, "Age should be in valid range"
def test_bbox_array_format(age_gender_model, mock_image): def test_bbox_array_format(age_gender_model, mock_image):
bbox_array = np.array([100, 100, 300, 300]) bbox_array = np.array([100, 100, 300, 300])
gender, age = age_gender_model.predict(mock_image, bbox_array) gender, age = age_gender_model.predict(mock_image, bbox_array)
assert gender in ['Male', 'Female'], "Should work with bbox as numpy array" assert gender in ["Male", "Female"], "Should work with bbox as numpy array"
assert 0 <= age <= 120, "Age should be in valid range" assert 0 <= age <= 120, "Age should be in valid range"
@@ -98,7 +98,7 @@ def test_multiple_predictions(age_gender_model, mock_image):
assert len(results) == 3, "Should have 3 predictions" assert len(results) == 3, "Should have 3 predictions"
for gender, age in results: for gender, age in results:
assert gender in ['Male', 'Female'] assert gender in ["Male", "Female"]
assert 0 <= age <= 120 assert 0 <= age <= 120

View File

@@ -16,7 +16,7 @@ def test_create_detector_retinaface():
""" """
Test creating a RetinaFace detector using factory function. Test creating a RetinaFace detector using factory function.
""" """
detector = create_detector('retinaface') detector = create_detector("retinaface")
assert detector is not None, "Failed to create RetinaFace detector" assert detector is not None, "Failed to create RetinaFace detector"
@@ -24,7 +24,7 @@ def test_create_detector_scrfd():
""" """
Test creating a SCRFD detector using factory function. Test creating a SCRFD detector using factory function.
""" """
detector = create_detector('scrfd') detector = create_detector("scrfd")
assert detector is not None, "Failed to create SCRFD detector" assert detector is not None, "Failed to create SCRFD detector"
@@ -33,10 +33,10 @@ def test_create_detector_with_config():
Test creating detector with custom configuration. Test creating detector with custom configuration.
""" """
detector = create_detector( detector = create_detector(
'retinaface', "retinaface",
model_name=RetinaFaceWeights.MNET_V2, model_name=RetinaFaceWeights.MNET_V2,
conf_thresh=0.8, conf_thresh=0.8,
nms_thresh=0.3 nms_thresh=0.3,
) )
assert detector is not None, "Failed to create detector with custom config" assert detector is not None, "Failed to create detector with custom config"
@@ -46,18 +46,14 @@ def test_create_detector_invalid_method():
Test that invalid detector method raises an error. Test that invalid detector method raises an error.
""" """
with pytest.raises((ValueError, KeyError)): with pytest.raises((ValueError, KeyError)):
create_detector('invalid_method') create_detector("invalid_method")
def test_create_detector_scrfd_with_model(): def test_create_detector_scrfd_with_model():
""" """
Test creating SCRFD detector with specific model. Test creating SCRFD detector with specific model.
""" """
detector = create_detector( detector = create_detector("scrfd", model_name=SCRFDWeights.SCRFD_10G_KPS, conf_thresh=0.5)
'scrfd',
model_name=SCRFDWeights.SCRFD_10G_KPS,
conf_thresh=0.5
)
assert detector is not None, "Failed to create SCRFD with specific model" assert detector is not None, "Failed to create SCRFD with specific model"
@@ -66,7 +62,7 @@ def test_create_recognizer_arcface():
""" """
Test creating an ArcFace recognizer using factory function. Test creating an ArcFace recognizer using factory function.
""" """
recognizer = create_recognizer('arcface') recognizer = create_recognizer("arcface")
assert recognizer is not None, "Failed to create ArcFace recognizer" assert recognizer is not None, "Failed to create ArcFace recognizer"
@@ -74,7 +70,7 @@ def test_create_recognizer_mobileface():
""" """
Test creating a MobileFace recognizer using factory function. Test creating a MobileFace recognizer using factory function.
""" """
recognizer = create_recognizer('mobileface') recognizer = create_recognizer("mobileface")
assert recognizer is not None, "Failed to create MobileFace recognizer" assert recognizer is not None, "Failed to create MobileFace recognizer"
@@ -82,7 +78,7 @@ def test_create_recognizer_sphereface():
""" """
Test creating a SphereFace recognizer using factory function. Test creating a SphereFace recognizer using factory function.
""" """
recognizer = create_recognizer('sphereface') recognizer = create_recognizer("sphereface")
assert recognizer is not None, "Failed to create SphereFace recognizer" assert recognizer is not None, "Failed to create SphereFace recognizer"
@@ -91,7 +87,7 @@ def test_create_recognizer_invalid_method():
Test that invalid recognizer method raises an error. Test that invalid recognizer method raises an error.
""" """
with pytest.raises((ValueError, KeyError)): with pytest.raises((ValueError, KeyError)):
create_recognizer('invalid_method') create_recognizer("invalid_method")
# create_landmarker tests # create_landmarker tests
@@ -99,7 +95,7 @@ def test_create_landmarker():
""" """
Test creating a Landmark106 detector using factory function. Test creating a Landmark106 detector using factory function.
""" """
landmarker = create_landmarker('2d106det') landmarker = create_landmarker("2d106det")
assert landmarker is not None, "Failed to create Landmark106 detector" assert landmarker is not None, "Failed to create Landmark106 detector"
@@ -116,7 +112,7 @@ def test_create_landmarker_invalid_method():
Test that invalid landmarker method raises an error. Test that invalid landmarker method raises an error.
""" """
with pytest.raises((ValueError, KeyError)): with pytest.raises((ValueError, KeyError)):
create_landmarker('invalid_method') create_landmarker("invalid_method")
# detect_faces tests # detect_faces tests
@@ -125,7 +121,7 @@ def test_detect_faces_retinaface():
Test high-level detect_faces function with RetinaFace. Test high-level detect_faces function with RetinaFace.
""" """
mock_image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) mock_image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
faces = detect_faces(mock_image, method='retinaface') faces = detect_faces(mock_image, method="retinaface")
assert isinstance(faces, list), "detect_faces should return a list" assert isinstance(faces, list), "detect_faces should return a list"
@@ -135,7 +131,7 @@ def test_detect_faces_scrfd():
Test high-level detect_faces function with SCRFD. Test high-level detect_faces function with SCRFD.
""" """
mock_image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) mock_image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
faces = detect_faces(mock_image, method='scrfd') faces = detect_faces(mock_image, method="scrfd")
assert isinstance(faces, list), "detect_faces should return a list" assert isinstance(faces, list), "detect_faces should return a list"
@@ -145,13 +141,13 @@ def test_detect_faces_with_threshold():
Test detect_faces with custom confidence threshold. Test detect_faces with custom confidence threshold.
""" """
mock_image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) mock_image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
faces = detect_faces(mock_image, method='retinaface', conf_thresh=0.8) faces = detect_faces(mock_image, method="retinaface", conf_thresh=0.8)
assert isinstance(faces, list), "detect_faces should return a list" assert isinstance(faces, list), "detect_faces should return a list"
# All detections should respect threshold # All detections should respect threshold
for face in faces: for face in faces:
assert face['confidence'] >= 0.8, "All detections should meet confidence threshold" assert face["confidence"] >= 0.8, "All detections should meet confidence threshold"
def test_detect_faces_default_method(): def test_detect_faces_default_method():
@@ -169,7 +165,7 @@ def test_detect_faces_empty_image():
Test detect_faces on a blank image. Test detect_faces on a blank image.
""" """
empty_image = np.zeros((640, 640, 3), dtype=np.uint8) empty_image = np.zeros((640, 640, 3), dtype=np.uint8)
faces = detect_faces(empty_image, method='retinaface') faces = detect_faces(empty_image, method="retinaface")
assert isinstance(faces, list), "Should return a list even for empty image" assert isinstance(faces, list), "Should return a list even for empty image"
assert len(faces) == 0, "Should detect no faces in blank image" assert len(faces) == 0, "Should detect no faces in blank image"
@@ -193,8 +189,8 @@ def test_list_available_detectors_contents():
detectors = list_available_detectors() detectors = list_available_detectors()
# Should include at least these detectors # Should include at least these detectors
assert 'retinaface' in detectors, "Should include 'retinaface'" assert "retinaface" in detectors, "Should include 'retinaface'"
assert 'scrfd' in detectors, "Should include 'scrfd'" assert "scrfd" in detectors, "Should include 'scrfd'"
# Integration tests # Integration tests
@@ -202,7 +198,7 @@ def test_detector_inference_from_factory():
""" """
Test that detector created from factory can perform inference. Test that detector created from factory can perform inference.
""" """
detector = create_detector('retinaface') detector = create_detector("retinaface")
mock_image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) mock_image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
faces = detector.detect(mock_image) faces = detector.detect(mock_image)
@@ -213,7 +209,7 @@ def test_recognizer_inference_from_factory():
""" """
Test that recognizer created from factory can perform inference. Test that recognizer created from factory can perform inference.
""" """
recognizer = create_recognizer('arcface') recognizer = create_recognizer("arcface")
mock_image = np.random.randint(0, 255, (112, 112, 3), dtype=np.uint8) mock_image = np.random.randint(0, 255, (112, 112, 3), dtype=np.uint8)
embedding = recognizer.get_embedding(mock_image) embedding = recognizer.get_embedding(mock_image)
@@ -225,7 +221,7 @@ def test_landmarker_inference_from_factory():
""" """
Test that landmarker created from factory can perform inference. Test that landmarker created from factory can perform inference.
""" """
landmarker = create_landmarker('2d106det') landmarker = create_landmarker("2d106det")
mock_image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) mock_image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
mock_bbox = [100, 100, 300, 300] mock_bbox = [100, 100, 300, 300]
@@ -238,8 +234,8 @@ def test_multiple_detector_creation():
""" """
Test that multiple detectors can be created independently. Test that multiple detectors can be created independently.
""" """
detector1 = create_detector('retinaface') detector1 = create_detector("retinaface")
detector2 = create_detector('scrfd') detector2 = create_detector("scrfd")
assert detector1 is not None assert detector1 is not None
assert detector2 is not None assert detector2 is not None
@@ -250,8 +246,8 @@ def test_detector_with_different_configs():
""" """
Test creating multiple detectors with different configurations. Test creating multiple detectors with different configurations.
""" """
detector_high_thresh = create_detector('retinaface', conf_thresh=0.9) detector_high_thresh = create_detector("retinaface", conf_thresh=0.9)
detector_low_thresh = create_detector('retinaface', conf_thresh=0.3) detector_low_thresh = create_detector("retinaface", conf_thresh=0.3)
mock_image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) mock_image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
@@ -269,9 +265,9 @@ def test_factory_returns_correct_types():
""" """
from uniface import RetinaFace, ArcFace, Landmark106 from uniface import RetinaFace, ArcFace, Landmark106
detector = create_detector('retinaface') detector = create_detector("retinaface")
recognizer = create_recognizer('arcface') recognizer = create_recognizer("arcface")
landmarker = create_landmarker('2d106det') landmarker = create_landmarker("2d106det")
assert isinstance(detector, RetinaFace), "Should return RetinaFace instance" assert isinstance(detector, RetinaFace), "Should return RetinaFace instance"
assert isinstance(recognizer, ArcFace), "Should return ArcFace instance" assert isinstance(recognizer, ArcFace), "Should return ArcFace instance"

View File

@@ -41,13 +41,16 @@ def mock_landmarks():
""" """
Create mock 5-point facial landmarks. Create mock 5-point facial landmarks.
""" """
return np.array([ return np.array(
[38.2946, 51.6963], [
[73.5318, 51.5014], [38.2946, 51.6963],
[56.0252, 71.7366], [73.5318, 51.5014],
[41.5493, 92.3655], [56.0252, 71.7366],
[70.7299, 92.2041] [41.5493, 92.3655],
], dtype=np.float32) [70.7299, 92.2041],
],
dtype=np.float32,
)
# ArcFace Tests # ArcFace Tests
@@ -173,8 +176,7 @@ def test_different_models_different_embeddings(arcface_model, mobileface_model,
# Embeddings should be different (with high probability for random input) # Embeddings should be different (with high probability for random input)
# We check that they're not identical # We check that they're not identical
assert not np.allclose(arcface_emb, mobileface_emb), \ assert not np.allclose(arcface_emb, mobileface_emb), "Different models should produce different embeddings"
"Different models should produce different embeddings"
def test_embedding_similarity_computation(arcface_model, mock_aligned_face): def test_embedding_similarity_computation(arcface_model, mock_aligned_face):
@@ -191,6 +193,7 @@ def test_embedding_similarity_computation(arcface_model, mock_aligned_face):
# Compute cosine similarity # Compute cosine similarity
from uniface import compute_similarity from uniface import compute_similarity
similarity = compute_similarity(emb1, emb2) similarity = compute_similarity(emb1, emb2)
# Similarity should be between -1 and 1 # Similarity should be between -1 and 1
@@ -205,6 +208,7 @@ def test_same_face_high_similarity(arcface_model, mock_aligned_face):
emb2 = arcface_model.get_embedding(mock_aligned_face) emb2 = arcface_model.get_embedding(mock_aligned_face)
from uniface import compute_similarity from uniface import compute_similarity
similarity = compute_similarity(emb1, emb2) similarity = compute_similarity(emb1, emb2)
# Same image should have similarity close to 1.0 # Same image should have similarity close to 1.0

View File

@@ -18,13 +18,16 @@ def mock_landmarks():
Create mock 5-point facial landmarks. Create mock 5-point facial landmarks.
Standard positions for a face roughly centered at (112/2, 112/2). Standard positions for a face roughly centered at (112/2, 112/2).
""" """
return np.array([ return np.array(
[38.2946, 51.6963], # Left eye [
[73.5318, 51.5014], # Right eye [38.2946, 51.6963], # Left eye
[56.0252, 71.7366], # Nose [73.5318, 51.5014], # Right eye
[41.5493, 92.3655], # Left mouth corner [56.0252, 71.7366], # Nose
[70.7299, 92.2041] # Right mouth corner [41.5493, 92.3655], # Left mouth corner
], dtype=np.float32) [70.7299, 92.2041], # Right mouth corner
],
dtype=np.float32,
)
# compute_similarity tests # compute_similarity tests
@@ -166,7 +169,7 @@ def test_face_alignment_landmarks_as_list(mock_image):
[73.5318, 51.5014], [73.5318, 51.5014],
[56.0252, 71.7366], [56.0252, 71.7366],
[41.5493, 92.3655], [41.5493, 92.3655],
[70.7299, 92.2041] [70.7299, 92.2041],
] ]
# Convert list to numpy array before passing to face_alignment # Convert list to numpy array before passing to face_alignment
@@ -201,9 +204,18 @@ def test_face_alignment_from_different_positions(mock_image):
""" """
# Landmarks at different positions # Landmarks at different positions
positions = [ positions = [
np.array([[100, 100], [150, 100], [125, 130], [110, 150], [140, 150]], dtype=np.float32), np.array(
np.array([[300, 200], [350, 200], [325, 230], [310, 250], [340, 250]], dtype=np.float32), [[100, 100], [150, 100], [125, 130], [110, 150], [140, 150]],
np.array([[500, 400], [550, 400], [525, 430], [510, 450], [540, 450]], dtype=np.float32), dtype=np.float32,
),
np.array(
[[300, 200], [350, 200], [325, 230], [310, 250], [340, 250]],
dtype=np.float32,
),
np.array(
[[500, 400], [550, 400], [525, 430], [510, 450], [540, 450]],
dtype=np.float32,
),
] ]
for landmarks in positions: for landmarks in positions:
@@ -216,13 +228,16 @@ def test_face_alignment_landmark_count(mock_image):
Test that face_alignment works specifically with 5-point landmarks. Test that face_alignment works specifically with 5-point landmarks.
""" """
# Standard 5-point landmarks # Standard 5-point landmarks
landmarks_5pt = np.array([ landmarks_5pt = np.array(
[38.2946, 51.6963], [
[73.5318, 51.5014], [38.2946, 51.6963],
[56.0252, 71.7366], [73.5318, 51.5014],
[41.5493, 92.3655], [56.0252, 71.7366],
[70.7299, 92.2041] [41.5493, 92.3655],
], dtype=np.float32) [70.7299, 92.2041],
],
dtype=np.float32,
)
aligned, _ = face_alignment(mock_image, landmarks_5pt, image_size=(112, 112)) aligned, _ = face_alignment(mock_image, landmarks_5pt, image_size=(112, 112))
assert aligned.shape == (112, 112, 3), "Should work with 5-point landmarks" assert aligned.shape == (112, 112, 3), "Should work with 5-point landmarks"