diff --git a/MODELS.md b/MODELS.md
index 86338bd..8f63d06 100644
--- a/MODELS.md
+++ b/MODELS.md
@@ -10,14 +10,14 @@ Complete guide to all available models, their performance characteristics, and s
RetinaFace models are trained on the WIDER FACE dataset and provide excellent accuracy-speed tradeoffs.
-| Model Name | Params | Size | Easy | Medium | Hard | Use Case |
-|---------------------|--------|--------|--------|--------|--------|----------------------------|
-| `MNET_025` | 0.4M | 1.7MB | 88.48% | 87.02% | 80.61% | Mobile/Edge devices |
-| `MNET_050` | 1.0M | 2.6MB | 89.42% | 87.97% | 82.40% | Mobile/Edge devices |
-| `MNET_V1` | 3.5M | 3.8MB | 90.59% | 89.14% | 84.13% | Balanced mobile |
-| `MNET_V2` ⭐ | 3.2M | 3.5MB | 91.70% | 91.03% | 86.60% | **Recommended default** |
-| `RESNET18` | 11.7M | 27MB | 92.50% | 91.02% | 86.63% | Server/High accuracy |
-| `RESNET34` | 24.8M | 56MB | 94.16% | 93.12% | 88.90% | Maximum accuracy |
+| Model Name | Params | Size | Easy | Medium | Hard | Use Case |
+| -------------- | ------ | ----- | ------ | ------ | ------ | ----------------------------- |
+| `MNET_025` | 0.4M | 1.7MB | 88.48% | 87.02% | 80.61% | Mobile/Edge devices |
+| `MNET_050` | 1.0M | 2.6MB | 89.42% | 87.97% | 82.40% | Mobile/Edge devices |
+| `MNET_V1` | 3.5M | 3.8MB | 90.59% | 89.14% | 84.13% | Balanced mobile |
+| `MNET_V2` ⭐ | 3.2M | 3.5MB | 91.70% | 91.03% | 86.60% | **Recommended default** |
+| `RESNET18` | 11.7M | 27MB | 92.50% | 91.02% | 86.63% | Server/High accuracy |
+| `RESNET34` | 24.8M | 56MB | 94.16% | 93.12% | 88.90% | Maximum accuracy |
**Accuracy**: WIDER FACE validation set (Easy/Medium/Hard subsets) - from [RetinaFace paper](https://arxiv.org/abs/1905.00641)
**Speed**: Benchmark on your own hardware using `scripts/run_detection.py --iterations 100`
@@ -46,10 +46,10 @@ detector = RetinaFace(
SCRFD (Sample and Computation Redistribution for Efficient Face Detection) models offer state-of-the-art speed-accuracy tradeoffs.
-| Model Name | Params | Size | Easy | Medium | Hard | Use Case |
-|-----------------|--------|-------|--------|--------|--------|----------------------------|
-| `SCRFD_500M` | 0.6M | 2.5MB | 90.57% | 88.12% | 68.51% | Real-time applications |
-| `SCRFD_10G` ⭐ | 4.2M | 17MB | 95.16% | 93.87% | 83.05% | **High accuracy + speed** |
+| Model Name | Params | Size | Easy | Medium | Hard | Use Case |
+| ---------------- | ------ | ----- | ------ | ------ | ------ | ------------------------------- |
+| `SCRFD_500M` | 0.6M | 2.5MB | 90.57% | 88.12% | 68.51% | Real-time applications |
+| `SCRFD_10G` ⭐ | 4.2M | 17MB | 95.16% | 93.87% | 83.05% | **High accuracy + speed** |
**Accuracy**: WIDER FACE validation set - from [SCRFD paper](https://arxiv.org/abs/2105.04714)
**Speed**: Benchmark on your own hardware using `scripts/run_detection.py --iterations 100`
@@ -76,16 +76,58 @@ detector = SCRFD(
---
+### YOLOv5-Face Family
+
+YOLOv5-Face models provide excellent detection accuracy with 5-point facial landmarks, optimized for real-time applications.
+
+| Model Name | Params | Size | Easy | Medium | Hard | FLOPs (G) | Use Case |
+| -------------- | ------ | ---- | ------ | ------ | ------ | --------- | ------------------------------ |
+| `YOLOV5S` ⭐ | 7.1M | 28MB | 94.33% | 92.61% | 83.15% | 5.751 | **Real-time + accuracy** |
+| `YOLOV5M` | 21.1M | 84MB | 95.30% | 93.76% | 85.28% | 18.146 | High accuracy |
+
+**Accuracy**: WIDER FACE validation set - from [YOLOv5-Face paper](https://arxiv.org/abs/2105.12931)
+**Speed**: Benchmark on your own hardware using `scripts/run_detection.py --iterations 100`
+**Note**: Fixed input size of 640×640. Models exported to ONNX from [deepcam-cn/yolov5-face](https://github.com/deepcam-cn/yolov5-face)
+
+#### Usage
+
+```python
+from uniface import YOLOv5Face
+from uniface.constants import YOLOv5FaceWeights
+
+# Real-time detection (recommended)
+detector = YOLOv5Face(
+ model_name=YOLOv5FaceWeights.YOLOV5S,
+ conf_thresh=0.6,
+ nms_thresh=0.5
+)
+
+# High accuracy
+detector = YOLOv5Face(
+ model_name=YOLOv5FaceWeights.YOLOV5M,
+ conf_thresh=0.6
+)
+
+# Detect faces with landmarks
+faces = detector.detect(image)
+for face in faces:
+ bbox = face['bbox'] # [x1, y1, x2, y2]
+ confidence = face['confidence']
+ landmarks = face['landmarks'] # 5-point landmarks (5, 2)
+```
+
+---
+
## Face Recognition Models
### ArcFace
State-of-the-art face recognition using additive angular margin loss.
-| Model Name | Backbone | Params | Size | Use Case |
-|-------------|-------------|--------|-------|----------------------------|
-| `MNET` ⭐ | MobileNet | 2.0M | 8MB | **Balanced (recommended)** |
-| `RESNET` | ResNet50 | 43.6M | 166MB | Maximum accuracy |
+| Model Name | Backbone | Params | Size | Use Case |
+| ----------- | --------- | ------ | ----- | -------------------------------- |
+| `MNET` ⭐ | MobileNet | 2.0M | 8MB | **Balanced (recommended)** |
+| `RESNET` | ResNet50 | 43.6M | 166MB | Maximum accuracy |
**Dataset**: Trained on MS1M-V2 (5.8M images, 85K identities)
**Accuracy**: Benchmark on your own dataset or use standard face verification benchmarks
@@ -113,12 +155,12 @@ embedding = recognizer.get_normalized_embedding(image, landmarks)
Lightweight face recognition optimized for mobile devices.
-| Model Name | Backbone | Params | Size | LFW | CALFW | CPLFW | AgeDB-30 | Use Case |
-|-----------------|-----------------|--------|------|-------|-------|-------|----------|--------------------|
-| `MNET_025` | MobileNetV1 0.25| 0.36M | 1MB | 98.76%| 92.02%| 82.37%| 90.02% | Ultra-lightweight |
-| `MNET_V2` ⭐ | MobileNetV2 | 2.29M | 4MB | 99.55%| 94.87%| 86.89%| 95.16% | **Mobile/Edge** |
-| `MNET_V3_SMALL` | MobileNetV3-S | 1.25M | 3MB | 99.30%| 93.77%| 85.29%| 92.79% | Mobile optimized |
-| `MNET_V3_LARGE` | MobileNetV3-L | 3.52M | 10MB | 99.53%| 94.56%| 86.79%| 95.13% | Balanced mobile |
+| Model Name | Backbone | Params | Size | LFW | CALFW | CPLFW | AgeDB-30 | Use Case |
+| ----------------- | ---------------- | ------ | ---- | ------ | ------ | ------ | -------- | --------------------- |
+| `MNET_025` | MobileNetV1 0.25 | 0.36M | 1MB | 98.76% | 92.02% | 82.37% | 90.02% | Ultra-lightweight |
+| `MNET_V2` ⭐ | MobileNetV2 | 2.29M | 4MB | 99.55% | 94.87% | 86.89% | 95.16% | **Mobile/Edge** |
+| `MNET_V3_SMALL` | MobileNetV3-S | 1.25M | 3MB | 99.30% | 93.77% | 85.29% | 92.79% | Mobile optimized |
+| `MNET_V3_LARGE` | MobileNetV3-L | 3.52M | 10MB | 99.53% | 94.56% | 86.79% | 95.13% | Balanced mobile |
**Dataset**: Trained on MS1M-V2 (5.8M images, 85K identities)
**Accuracy**: Evaluated on LFW, CALFW, CPLFW, and AgeDB-30 benchmarks
@@ -140,10 +182,10 @@ recognizer = MobileFace(model_name=MobileFaceWeights.MNET_V2)
Face recognition using angular softmax loss.
-| Model Name | Backbone | Params | Size | LFW | CALFW | CPLFW | AgeDB-30 | Use Case |
-|-------------|----------|--------|------|-------|-------|-------|----------|----------------------|
-| `SPHERE20` | Sphere20 | 24.5M | 50MB | 99.67%| 95.61%| 88.75%| 96.58% | Research/Comparison |
-| `SPHERE36` | Sphere36 | 34.6M | 92MB | 99.72%| 95.64%| 89.92%| 96.83% | Research/Comparison |
+| Model Name | Backbone | Params | Size | LFW | CALFW | CPLFW | AgeDB-30 | Use Case |
+| ------------ | -------- | ------ | ---- | ------ | ------ | ------ | -------- | ------------------- |
+| `SPHERE20` | Sphere20 | 24.5M | 50MB | 99.67% | 95.61% | 88.75% | 96.58% | Research/Comparison |
+| `SPHERE36` | Sphere36 | 34.6M | 92MB | 99.72% | 95.64% | 89.92% | 96.83% | Research/Comparison |
**Dataset**: Trained on MS1M-V2 (5.8M images, 85K identities)
**Accuracy**: Evaluated on LFW, CALFW, CPLFW, and AgeDB-30 benchmarks
@@ -166,9 +208,9 @@ recognizer = SphereFace(model_name=SphereFaceWeights.SPHERE20)
High-precision facial landmark localization.
-| Model Name | Points | Params | Size | Use Case |
-|------------|--------|--------|------|-----------------------------|
-| `2D106` | 106 | 3.7M | 14MB | Face alignment, analysis |
+| Model Name | Points | Params | Size | Use Case |
+| ---------- | ------ | ------ | ---- | ------------------------ |
+| `2D106` | 106 | 3.7M | 14MB | Face alignment, analysis |
**Note**: Provides 106 facial keypoints for detailed face analysis and alignment
@@ -183,6 +225,7 @@ landmarks = landmarker.get_landmarks(image, bbox)
```
**Landmark Groups:**
+
- Face contour: 0-32 (33 points)
- Eyebrows: 33-50 (18 points)
- Nose: 51-62 (12 points)
@@ -195,9 +238,9 @@ landmarks = landmarker.get_landmarks(image, bbox)
### Age & Gender Detection
-| Model Name | Attributes | Params | Size | Use Case |
-|------------|-------------|--------|------|-------------------|
-| `DEFAULT` | Age, Gender | 2.1M | 8MB | General purpose |
+| Model Name | Attributes | Params | Size | Use Case |
+| ----------- | ----------- | ------ | ---- | --------------- |
+| `DEFAULT` | Age, Gender | 2.1M | 8MB | General purpose |
**Dataset**: Trained on CelebA
**Note**: Accuracy varies by demographic and image quality. Test on your specific use case.
@@ -217,10 +260,10 @@ gender_id, age = predictor.predict(image, bbox)
### Emotion Detection
-| Model Name | Classes | Params | Size | Use Case |
-|--------------|---------|--------|------|-----------------------|
-| `AFFECNET7` | 7 | 0.5M | 2MB | 7-class emotion |
-| `AFFECNET8` | 8 | 0.5M | 2MB | 8-class emotion |
+| Model Name | Classes | Params | Size | Use Case |
+| ------------- | ------- | ------ | ---- | --------------- |
+| `AFFECNET7` | 7 | 0.5M | 2MB | 7-class emotion |
+| `AFFECNET8` | 8 | 0.5M | 2MB | 8-class emotion |
**Classes (7)**: Neutral, Happy, Sad, Surprise, Fear, Disgust, Anger
**Classes (8)**: Above + Contempt
@@ -240,118 +283,6 @@ emotion, confidence = predictor.predict(image, landmarks)
---
-## Model Selection Guide
-
-### By Use Case
-
-#### Mobile/Edge Devices
-- **Detection**: `RetinaFace(MNET_025)` or `SCRFD(SCRFD_500M)`
-- **Recognition**: `MobileFace(MNET_V2)`
-- **Priority**: Speed, small model size
-
-#### Real-Time Applications (Webcam, Video)
-- **Detection**: `RetinaFace(MNET_V2)` or `SCRFD(SCRFD_500M)`
-- **Recognition**: `ArcFace(MNET)`
-- **Priority**: Speed-accuracy balance
-
-#### High-Accuracy Applications (Security, Verification)
-- **Detection**: `SCRFD(SCRFD_10G)` or `RetinaFace(RESNET34)`
-- **Recognition**: `ArcFace(RESNET)`
-- **Priority**: Maximum accuracy
-
-#### Server/Cloud Deployment
-- **Detection**: `SCRFD(SCRFD_10G)`
-- **Recognition**: `ArcFace(RESNET)`
-- **Priority**: Accuracy, batch processing
-
----
-
-### By Hardware
-
-#### Apple Silicon (M1/M2/M3/M4)
-**Recommended**: All models work well with ARM64 optimizations (automatically included)
-
-```bash
-pip install uniface
-```
-
-**Recommended models**:
-- **Fast**: `SCRFD(SCRFD_500M)` - Lightweight, real-time capable
-- **Balanced**: `RetinaFace(MNET_V2)` - Good accuracy/speed tradeoff
-- **Accurate**: `SCRFD(SCRFD_10G)` - High accuracy
-
-**Benchmark on your M4**: `python scripts/run_detection.py --iterations 100`
-
-#### NVIDIA GPU (CUDA)
-**Recommended**: Larger models for maximum throughput
-
-```bash
-pip install uniface[gpu]
-```
-
-**Recommended models**:
-- **Fast**: `SCRFD(SCRFD_500M)` - Maximum throughput
-- **Balanced**: `SCRFD(SCRFD_10G)` - Best overall
-- **Accurate**: `RetinaFace(RESNET34)` - Highest accuracy
-
-#### CPU Only
-**Recommended**: Lightweight models
-
-**Recommended models**:
-- **Fast**: `RetinaFace(MNET_025)` - Smallest, fastest
-- **Balanced**: `RetinaFace(MNET_V2)` - Recommended default
-- **Accurate**: `SCRFD(SCRFD_10G)` - Best accuracy on CPU
-
-**Note**: FPS values vary significantly based on image size, number of faces, and hardware. Always benchmark on your specific setup.
-
----
-
-## Benchmark Details
-
-### How to Benchmark
-
-Run benchmarks on your own hardware:
-
-```bash
-# Detection speed
-python scripts/run_detection.py --image assets/test.jpg --iterations 100
-
-# Compare models
-python scripts/run_detection.py --image assets/test.jpg --method retinaface --iterations 100
-python scripts/run_detection.py --image assets/test.jpg --method scrfd --iterations 100
-```
-
-### Accuracy Metrics Explained
-
-- **WIDER FACE**: Standard face detection benchmark with three difficulty levels
- - **Easy**: Large faces (>50px), clear backgrounds
- - **Medium**: Medium-sized faces (30-50px), moderate occlusion
- - **Hard**: Small faces (<30px), heavy occlusion, blur
-
- *Accuracy values are from the original papers - see references below*
-
-- **Model Size**: ONNX model file size (affects download time and memory)
-- **Params**: Number of model parameters (affects inference speed)
-
-### Important Notes
-
-1. **Speed varies by**:
- - Image resolution
- - Number of faces in image
- - Hardware (CPU/GPU/CoreML)
- - Batch size
- - Operating system
-
-2. **Accuracy varies by**:
- - Image quality
- - Lighting conditions
- - Face pose and occlusion
- - Demographic factors
-
-3. **Always benchmark on your specific use case** before choosing a model
-
----
-
## Model Updates
Models are automatically downloaded and cached on first use. Cache location: `~/.uniface/models/`
@@ -388,6 +319,8 @@ python scripts/download_model.py --model MNET_V2
### Model Training & Architectures
- **RetinaFace Training**: [yakhyo/retinaface-pytorch](https://github.com/yakhyo/retinaface-pytorch) - PyTorch implementation and training code
+- **YOLOv5-Face Original**: [deepcam-cn/yolov5-face](https://github.com/deepcam-cn/yolov5-face) - Original PyTorch implementation
+- **YOLOv5-Face ONNX**: [yakhyo/yolov5-face-onnx-inference](https://github.com/yakhyo/yolov5-face-onnx-inference) - ONNX inference implementation
- **Face Recognition Training**: [yakhyo/face-recognition](https://github.com/yakhyo/face-recognition) - ArcFace, MobileFace, SphereFace training code
- **InsightFace**: [deepinsight/insightface](https://github.com/deepinsight/insightface) - Model architectures and pretrained weights
@@ -395,6 +328,6 @@ python scripts/download_model.py --model MNET_V2
- **RetinaFace**: [Single-Shot Multi-Level Face Localisation in the Wild](https://arxiv.org/abs/1905.00641)
- **SCRFD**: [Sample and Computation Redistribution for Efficient Face Detection](https://arxiv.org/abs/2105.04714)
+- **YOLOv5-Face**: [YOLO5Face: Why Reinventing a Face Detector](https://arxiv.org/abs/2105.12931)
- **ArcFace**: [Additive Angular Margin Loss for Deep Face Recognition](https://arxiv.org/abs/1801.07698)
- **SphereFace**: [Deep Hypersphere Embedding for Face Recognition](https://arxiv.org/abs/1704.08063)
-
diff --git a/QUICKSTART.md b/QUICKSTART.md
index ab502ea..384ea20 100644
--- a/QUICKSTART.md
+++ b/QUICKSTART.md
@@ -271,8 +271,8 @@ Choose the right model for your use case:
### Detection Models
```python
-from uniface.detection import RetinaFace, SCRFD
-from uniface.constants import RetinaFaceWeights, SCRFDWeights
+from uniface.detection import RetinaFace, SCRFD, YOLOv5Face
+from uniface.constants import RetinaFaceWeights, SCRFDWeights, YOLOv5FaceWeights
# Fast detection (mobile/edge devices)
detector = RetinaFace(
@@ -285,6 +285,13 @@ detector = RetinaFace(
model_name=RetinaFaceWeights.MNET_V2
)
+# Real-time with high accuracy
+detector = YOLOv5Face(
+ model_name=YOLOv5FaceWeights.YOLOV5S,
+ conf_thresh=0.6,
+ nms_thresh=0.5
+)
+
# High accuracy (server/GPU)
detector = SCRFD(
model_name=SCRFDWeights.SCRFD_10G_KPS,
@@ -367,9 +374,7 @@ from uniface import retinaface # Module, not class
## References
- **RetinaFace Training**: [yakhyo/retinaface-pytorch](https://github.com/yakhyo/retinaface-pytorch)
+- **YOLOv5-Face Original**: [deepcam-cn/yolov5-face](https://github.com/deepcam-cn/yolov5-face)
+- **YOLOv5-Face ONNX**: [yakhyo/yolov5-face-onnx-inference](https://github.com/yakhyo/yolov5-face-onnx-inference)
- **Face Recognition Training**: [yakhyo/face-recognition](https://github.com/yakhyo/face-recognition)
- **InsightFace**: [deepinsight/insightface](https://github.com/deepinsight/insightface)
-
----
-
-Happy coding! 🚀
diff --git a/README.md b/README.md
index 32a0f23..2ea115a 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,6 @@
[](https://pepy.tech/project/uniface)
[](https://deepwiki.com/yakhyo/uniface)
-
@@ -190,8 +189,8 @@ landmarker = Landmark106()
### Direct Model Instantiation
```python
-from uniface import RetinaFace, SCRFD, ArcFace, MobileFace, SphereFace
-from uniface.constants import RetinaFaceWeights
+from uniface import RetinaFace, SCRFD, YOLOv5Face, ArcFace, MobileFace, SphereFace
+from uniface.constants import RetinaFaceWeights, YOLOv5FaceWeights
# Detection
detector = RetinaFace(
@@ -200,6 +199,13 @@ detector = RetinaFace(
nms_thresh=0.4
)
+# YOLOv5-Face detection
+detector = YOLOv5Face(
+ model_name=YOLOv5FaceWeights.YOLOV5S,
+ conf_thresh=0.6,
+ nms_thresh=0.5
+)
+
# Recognition
recognizer = ArcFace() # Uses default weights
recognizer = MobileFace() # Lightweight alternative
@@ -228,8 +234,10 @@ faces = detect_faces(image, method='retinaface', conf_thresh=0.8)
| retinaface_r34 | 94.16% | 93.12% | 88.90% | High accuracy |
| scrfd_500m | 90.57% | 88.12% | 68.51% | Real-time applications |
| scrfd_10g | 95.16% | 93.87% | 83.05% | Best accuracy/speed |
+| yolov5s_face | 94.33% | 92.61% | 83.15% | Real-time + accuracy |
+| yolov5m_face | 95.30% | 93.76% | 85.28% | High accuracy |
-_Accuracy values from original papers: [RetinaFace](https://arxiv.org/abs/1905.00641), [SCRFD](https://arxiv.org/abs/2105.04714)_
+_Accuracy values from original papers: [RetinaFace](https://arxiv.org/abs/1905.00641), [SCRFD](https://arxiv.org/abs/2105.04714), [YOLOv5-Face](https://arxiv.org/abs/2105.12931)_
**Benchmark on your hardware:**
@@ -443,20 +451,12 @@ uniface/
## References
-### Model Training & Architectures
-
- **RetinaFace Training**: [yakhyo/retinaface-pytorch](https://github.com/yakhyo/retinaface-pytorch) - PyTorch implementation and training code
+- **YOLOv5-Face Original**: [deepcam-cn/yolov5-face](https://github.com/deepcam-cn/yolov5-face) - Original PyTorch implementation
+- **YOLOv5-Face ONNX**: [yakhyo/yolov5-face-onnx-inference](https://github.com/yakhyo/yolov5-face-onnx-inference) - ONNX inference implementation
- **Face Recognition Training**: [yakhyo/face-recognition](https://github.com/yakhyo/face-recognition) - ArcFace, MobileFace, SphereFace training code
- **InsightFace**: [deepinsight/insightface](https://github.com/deepinsight/insightface) - Model architectures and pretrained weights
-### Papers
-
-- **RetinaFace**: [Single-Shot Multi-Level Face Localisation in the Wild](https://arxiv.org/abs/1905.00641)
-- **SCRFD**: [Sample and Computation Redistribution for Efficient Face Detection](https://arxiv.org/abs/2105.04714)
-- **ArcFace**: [Additive Angular Margin Loss for Deep Face Recognition](https://arxiv.org/abs/1801.07698)
-
----
-
## Contributing
Contributions are welcome! Please open an issue or submit a pull request on [GitHub](https://github.com/yakhyo/uniface).
diff --git a/pyproject.toml b/pyproject.toml
index f829d8f..a966389 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "uniface"
-version = "1.1.2"
+version = "1.2.0"
description = "UniFace: A Comprehensive Library for Face Detection, Recognition, Landmark Analysis, Age, and Gender Detection"
readme = "README.md"
license = { text = "MIT" }
diff --git a/scripts/run_detection.py b/scripts/run_detection.py
index ac8a8b5..8dc3d7e 100644
--- a/scripts/run_detection.py
+++ b/scripts/run_detection.py
@@ -7,7 +7,7 @@ import os
import cv2
-from uniface.detection import SCRFD, RetinaFace
+from uniface.detection import SCRFD, RetinaFace, YOLOv5Face
from uniface.visualization import draw_detections
@@ -75,15 +75,21 @@ def main():
parser = argparse.ArgumentParser(description='Run face detection')
parser.add_argument('--image', type=str, help='Path to input image')
parser.add_argument('--webcam', action='store_true', help='Use webcam')
- parser.add_argument('--method', type=str, default='retinaface', choices=['retinaface', 'scrfd'])
- parser.add_argument('--threshold', type=float, default=0.6, help='Visualization threshold')
+ parser.add_argument('--method', type=str, default='retinaface', choices=['retinaface', 'scrfd', 'yolov5face'])
+ parser.add_argument('--threshold', type=float, default=0.25, help='Visualization threshold')
parser.add_argument('--save_dir', type=str, default='outputs')
args = parser.parse_args()
if not args.image and not args.webcam:
parser.error('Either --image or --webcam must be specified')
- detector = RetinaFace() if args.method == 'retinaface' else SCRFD()
+ if args.method == 'retinaface':
+ detector = RetinaFace()
+ elif args.method == 'scrfd':
+ detector = SCRFD()
+ else:
+ from uniface.constants import YOLOv5FaceWeights
+ detector = YOLOv5Face(model_name=YOLOv5FaceWeights.YOLOV5M)
if args.webcam:
run_webcam(detector, args.threshold)
diff --git a/tests/test_factory.py b/tests/test_factory.py
index 0f134bd..2aab407 100644
--- a/tests/test_factory.py
+++ b/tests/test_factory.py
@@ -263,7 +263,7 @@ def test_factory_returns_correct_types():
"""
Test that factory functions return instances of the correct types.
"""
- from uniface import RetinaFace, ArcFace, Landmark106
+ from uniface import ArcFace, Landmark106, RetinaFace
detector = create_detector('retinaface')
recognizer = create_recognizer('arcface')
diff --git a/uniface/__init__.py b/uniface/__init__.py
index fed089b..6f92928 100644
--- a/uniface/__init__.py
+++ b/uniface/__init__.py
@@ -13,7 +13,7 @@
__license__ = 'MIT'
__author__ = 'Yakhyokhuja Valikhujaev'
-__version__ = '1.1.2'
+__version__ = '1.2.0'
from uniface.face_utils import compute_similarity, face_alignment
@@ -32,6 +32,7 @@ except ImportError:
from .detection import (
SCRFD,
RetinaFace,
+ YOLOv5Face,
create_detector,
detect_faces,
list_available_detectors,
@@ -55,6 +56,7 @@ __all__ = [
# Detection models
'RetinaFace',
'SCRFD',
+ 'YOLOv5Face',
# Recognition models
'ArcFace',
'MobileFace',
diff --git a/uniface/constants.py b/uniface/constants.py
index f2178a3..586ce9e 100644
--- a/uniface/constants.py
+++ b/uniface/constants.py
@@ -55,6 +55,20 @@ class SCRFDWeights(str, Enum):
SCRFD_500M_KPS = "scrfd_500m"
+class YOLOv5FaceWeights(str, Enum):
+ """
+ Trained on WIDER FACE dataset.
+ Original implementation: https://github.com/deepcam-cn/yolov5-face
+ Exported to ONNX from: https://github.com/yakhyo/yolov5-face-onnx-inference
+
+ Model Performance (WIDER FACE):
+ - YOLOV5S: 7.1M params, 28MB, 94.33% Easy / 92.61% Medium / 83.15% Hard
+ - YOLOV5M: 21.1M params, 84MB, 95.30% Easy / 93.76% Medium / 85.28% Hard
+ """
+ YOLOV5S = "yolov5s_face"
+ YOLOV5M = "yolov5m_face"
+
+
class DDAMFNWeights(str, Enum):
"""
Trained on AffectNet dataset.
@@ -102,6 +116,9 @@ MODEL_URLS: Dict[Enum, str] = {
# SCRFD
SCRFDWeights.SCRFD_10G_KPS: 'https://github.com/yakhyo/uniface/releases/download/weights/scrfd_10g_kps.onnx',
SCRFDWeights.SCRFD_500M_KPS: 'https://github.com/yakhyo/uniface/releases/download/weights/scrfd_500m_kps.onnx',
+ # YOLOv5-Face
+ YOLOv5FaceWeights.YOLOV5S: 'https://github.com/yakhyo/yolov5-face-onnx-inference/releases/download/weights/yolov5s_face.onnx',
+ YOLOv5FaceWeights.YOLOV5M: 'https://github.com/yakhyo/yolov5-face-onnx-inference/releases/download/weights/yolov5m_face.onnx',
# DDAFM
DDAMFNWeights.AFFECNET7: 'https://github.com/yakhyo/uniface/releases/download/weights/affecnet7.script',
DDAMFNWeights.AFFECNET8: 'https://github.com/yakhyo/uniface/releases/download/weights/affecnet8.script',
@@ -133,6 +150,9 @@ MODEL_SHA256: Dict[Enum, str] = {
# SCRFD
SCRFDWeights.SCRFD_10G_KPS: '5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91',
SCRFDWeights.SCRFD_500M_KPS: '5e4447f50245bbd7966bd6c0fa52938c61474a04ec7def48753668a9d8b4ea3a',
+ # YOLOv5-Face
+ YOLOv5FaceWeights.YOLOV5S: 'fc682801cd5880e1e296184a14aea0035486b5146ec1a1389d2e7149cb134bb2',
+ YOLOv5FaceWeights.YOLOV5M: '04302ce27a15bde3e20945691b688e2dd018a10e92dd8932146bede6a49207b2',
# DDAFM
DDAMFNWeights.AFFECNET7: '10535bf8b6afe8e9d6ae26cea6c3add9a93036e9addb6adebfd4a972171d015d',
DDAMFNWeights.AFFECNET8: '8c66963bc71db42796a14dfcbfcd181b268b65a3fc16e87147d6a3a3d7e0f487',
diff --git a/uniface/detection/__init__.py b/uniface/detection/__init__.py
index 87538f6..5d0b140 100644
--- a/uniface/detection/__init__.py
+++ b/uniface/detection/__init__.py
@@ -10,6 +10,7 @@ import numpy as np
from .base import BaseDetector
from .retinaface import RetinaFace
from .scrfd import SCRFD
+from .yolov5 import YOLOv5Face
# Global cache for detector instances
_detector_cache: Dict[str, BaseDetector] = {}
@@ -59,6 +60,7 @@ def create_detector(method: str = 'retinaface', **kwargs) -> BaseDetector:
method (str): Detection method. Options:
- 'retinaface': RetinaFace detector (default)
- 'scrfd': SCRFD detector (fast and accurate)
+ - 'yolov5face': YOLOv5-Face detector (accurate with landmarks)
**kwargs: Detector-specific parameters
Returns:
@@ -86,6 +88,14 @@ def create_detector(method: str = 'retinaface', **kwargs) -> BaseDetector:
... conf_thresh=0.8,
... nms_thresh=0.4
... )
+
+ >>> # YOLOv5-Face detector
+ >>> detector = create_detector(
+ ... 'yolov5face',
+ ... model_name=YOLOv5FaceWeights.YOLOV5S,
+ ... conf_thresh=0.25,
+ ... nms_thresh=0.45
+ ... )
"""
method = method.lower()
@@ -95,8 +105,11 @@ def create_detector(method: str = 'retinaface', **kwargs) -> BaseDetector:
elif method == 'scrfd':
return SCRFD(**kwargs)
+ elif method == 'yolov5face':
+ return YOLOv5Face(**kwargs)
+
else:
- available_methods = ['retinaface', 'scrfd']
+ available_methods = ['retinaface', 'scrfd', 'yolov5face']
raise ValueError(f"Unsupported detection method: '{method}'. Available methods: {available_methods}")
@@ -130,6 +143,17 @@ def list_available_detectors() -> Dict[str, Dict[str, Any]]:
'input_size': (640, 640),
},
},
+ 'yolov5face': {
+ 'description': 'YOLOv5-Face detector - accurate face detection with landmarks',
+ 'supports_landmarks': True,
+ 'paper': 'https://arxiv.org/abs/2105.12931',
+ 'default_params': {
+ 'model_name': 'yolov5s_face',
+ 'conf_thresh': 0.25,
+ 'nms_thresh': 0.45,
+ 'input_size': 640,
+ },
+ },
}
@@ -139,5 +163,6 @@ __all__ = [
'list_available_detectors',
'SCRFD',
'RetinaFace',
+ 'YOLOv5Face',
'BaseDetector',
]
diff --git a/uniface/detection/retinaface.py b/uniface/detection/retinaface.py
index fe06541..8109b4c 100644
--- a/uniface/detection/retinaface.py
+++ b/uniface/detection/retinaface.py
@@ -38,6 +38,7 @@ class RetinaFace(BaseDetector):
dynamic_size (bool, optional): If True, generate anchors dynamically per input image. Defaults to False.
input_size (Tuple[int, int], optional): Fixed input size (width, height) if `dynamic_size=False`.
Defaults to (640, 640).
+ Note: Non-default sizes may cause slower inference and CoreML compatibility issues.
Attributes:
model_name (RetinaFaceWeights): Selected model variant.
diff --git a/uniface/detection/scrfd.py b/uniface/detection/scrfd.py
index e62b83e..ee4e815 100644
--- a/uniface/detection/scrfd.py
+++ b/uniface/detection/scrfd.py
@@ -31,7 +31,9 @@ class SCRFD(BaseDetector):
Specifies the SCRFD variant to load. Defaults to SCRFD_10G_KPS.
conf_thresh (float, optional): Confidence threshold for filtering detections. Defaults to 0.5.
nms_thresh (float, optional): Non-Maximum Suppression threshold. Defaults to 0.4.
- input_size (Tuple[int, int], optional): Input image size (width, height). Defaults to (640, 640).
+ input_size (Tuple[int, int], optional): Input image size (width, height).
+ Defaults to (640, 640).
+ Note: Non-default sizes may cause slower inference and CoreML compatibility issues.
Attributes:
conf_thresh (float): Threshold used to filter low-confidence detections.
diff --git a/uniface/detection/yolov5.py b/uniface/detection/yolov5.py
new file mode 100644
index 0000000..188ba72
--- /dev/null
+++ b/uniface/detection/yolov5.py
@@ -0,0 +1,326 @@
+# Copyright 2025 Yakhyokhuja Valikhujaev
+# Author: Yakhyokhuja Valikhujaev
+# GitHub: https://github.com/yakhyo
+
+from typing import Any, Dict, List, Literal, Tuple
+
+import cv2
+import numpy as np
+
+from uniface.common import non_max_suppression
+from uniface.constants import YOLOv5FaceWeights
+from uniface.log import Logger
+from uniface.model_store import verify_model_weights
+from uniface.onnx_utils import create_onnx_session
+
+from .base import BaseDetector
+
+__all__ = ['YOLOv5Face']
+
+
+class YOLOv5Face(BaseDetector):
+ """
+ Face detector based on the YOLOv5-Face architecture.
+
+ Paper: https://arxiv.org/abs/2105.12931
+ Original Implementation: https://github.com/deepcam-cn/yolov5-face
+
+ Args:
+ **kwargs: Keyword arguments passed to BaseDetector and YOLOv5Face. Supported keys include:
+ model_name (YOLOv5FaceWeights, optional): Predefined model enum (e.g., `YOLOV5S`).
+ Specifies the YOLOv5-Face variant to load. Defaults to YOLOV5S.
+ conf_thresh (float, optional): Confidence threshold for filtering detections. Defaults to 0.25.
+ nms_thresh (float, optional): Non-Maximum Suppression threshold. Defaults to 0.45.
+ input_size (int, optional): Input image size. Defaults to 640.
+ Note: ONNX model is fixed at 640. Changing this will cause inference errors.
+ max_det (int, optional): Maximum number of detections to return. Defaults to 750.
+
+ Attributes:
+ conf_thresh (float): Threshold used to filter low-confidence detections.
+ nms_thresh (float): Threshold used during NMS to suppress overlapping boxes.
+ input_size (int): Image size to which inputs are resized before inference.
+ max_det (int): Maximum number of detections to return.
+ _model_path (str): Absolute path to the downloaded/verified model weights.
+
+ Raises:
+ ValueError: If the model weights are invalid or not found.
+ RuntimeError: If the ONNX model fails to load or initialize.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self._supports_landmarks = True # YOLOv5-Face supports landmarks
+
+ model_name = kwargs.get('model_name', YOLOv5FaceWeights.YOLOV5S)
+ conf_thresh = kwargs.get('conf_thresh', 0.6) # 0.6 is default from original YOLOv5-Face repository
+ nms_thresh = kwargs.get('nms_thresh', 0.5) # 0.5 is default from original YOLOv5-Face repository
+ input_size = kwargs.get('input_size', 640)
+ max_det = kwargs.get('max_det', 750)
+
+ # Validate input size
+ if input_size != 640:
+ raise ValueError(
+ f'YOLOv5Face only supports input_size=640 (got {input_size}). The ONNX model has a fixed input shape.'
+ )
+
+ self.conf_thresh = conf_thresh
+ self.nms_thresh = nms_thresh
+ self.input_size = input_size
+ self.max_det = max_det
+
+ Logger.info(
+ f'Initializing YOLOv5Face with model={model_name}, conf_thresh={conf_thresh}, '
+ f'nms_thresh={nms_thresh}, input_size={input_size}'
+ )
+
+ # Get path to model weights
+ self._model_path = verify_model_weights(model_name)
+ Logger.info(f'Verified model weights located at: {self._model_path}')
+
+ # Initialize model
+ self._initialize_model(self._model_path)
+
+ def _initialize_model(self, model_path: str) -> None:
+ """
+ Initializes an ONNX model session from the given path.
+
+ Args:
+ model_path (str): The file path to the ONNX model.
+
+ Raises:
+ RuntimeError: If the model fails to load, logs an error and raises an exception.
+ """
+ try:
+ self.session = create_onnx_session(model_path)
+ self.input_names = self.session.get_inputs()[0].name
+ self.output_names = [x.name for x in self.session.get_outputs()]
+ Logger.info(f'Successfully initialized the model from {model_path}')
+ except Exception as e:
+ Logger.error(f"Failed to load model from '{model_path}': {e}", exc_info=True)
+ raise RuntimeError(f"Failed to initialize model session for '{model_path}'") from e
+
+ def preprocess(self, image: np.ndarray) -> Tuple[np.ndarray, float, Tuple[int, int]]:
+ """
+ Preprocess image for inference.
+
+ Args:
+ image (np.ndarray): Input image (BGR format)
+
+ Returns:
+ Tuple[np.ndarray, float, Tuple[int, int]]: Preprocessed image, scale ratio, and padding
+ """
+ # Get original image shape
+ img_h, img_w = image.shape[:2]
+
+ # Calculate scale ratio
+ scale = min(self.input_size / img_h, self.input_size / img_w)
+ new_h, new_w = int(img_h * scale), int(img_w * scale)
+
+ # Resize image
+ img_resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+
+ # Create padded image
+ img_padded = np.full((self.input_size, self.input_size, 3), 114, dtype=np.uint8)
+
+ # Calculate padding
+ pad_h = (self.input_size - new_h) // 2
+ pad_w = (self.input_size - new_w) // 2
+
+ # Place resized image in center
+ img_padded[pad_h : pad_h + new_h, pad_w : pad_w + new_w] = img_resized
+
+ # Convert to RGB and normalize
+ img_rgb = cv2.cvtColor(img_padded, cv2.COLOR_BGR2RGB)
+ img_normalized = img_rgb.astype(np.float32) / 255.0
+
+ # Transpose to CHW format (HWC -> CHW) and add batch dimension
+ img_transposed = np.transpose(img_normalized, (2, 0, 1))
+ img_batch = np.expand_dims(img_transposed, axis=0)
+ img_batch = np.ascontiguousarray(img_batch)
+
+ return img_batch, scale, (pad_w, pad_h)
+
+ def inference(self, input_tensor: np.ndarray) -> List[np.ndarray]:
+ """Perform model inference on the preprocessed image tensor.
+
+ Args:
+ input_tensor (np.ndarray): Preprocessed input tensor.
+
+ Returns:
+ List[np.ndarray]: Raw model outputs.
+ """
+ return self.session.run(self.output_names, {self.input_names: input_tensor})
+
+ def postprocess(
+ self,
+ predictions: np.ndarray,
+ scale: float,
+ padding: Tuple[int, int],
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Postprocess model predictions.
+
+ Args:
+ predictions (np.ndarray): Raw model output
+ scale (float): Scale ratio used in preprocessing
+ padding (Tuple[int, int]): Padding used in preprocessing
+
+ Returns:
+ Tuple[np.ndarray, np.ndarray]: Filtered detections and landmarks
+ - detections: [x1, y1, x2, y2, conf]
+ - landmarks: [5, 2] for each detection
+ """
+ # predictions shape: (1, 25200, 16)
+ # 16 = [x, y, w, h, obj_conf, cls_conf, 10 landmarks (5 points * 2 coords)]
+
+ predictions = predictions[0] # Remove batch dimension
+
+ # Filter by confidence
+ mask = predictions[:, 4] >= self.conf_thresh
+ predictions = predictions[mask]
+
+ if len(predictions) == 0:
+ return np.array([]), np.array([])
+
+ # Convert from xywh to xyxy
+ boxes = self._xywh2xyxy(predictions[:, :4])
+
+ # Get confidence scores
+ scores = predictions[:, 4]
+
+ # Get landmarks (5 points, 10 coordinates)
+ landmarks = predictions[:, 5:15].copy()
+
+ # Apply NMS
+ detections_for_nms = np.hstack((boxes, scores[:, None])).astype(np.float32, copy=False)
+ keep = non_max_suppression(detections_for_nms, self.nms_thresh)
+
+ if len(keep) == 0:
+ return np.array([]), np.array([])
+
+ # Filter detections and limit to max_det
+ keep = keep[: self.max_det]
+ boxes = boxes[keep]
+ scores = scores[keep]
+ landmarks = landmarks[keep]
+
+ # Scale back to original image coordinates
+ pad_w, pad_h = padding
+ boxes[:, [0, 2]] = (boxes[:, [0, 2]] - pad_w) / scale
+ boxes[:, [1, 3]] = (boxes[:, [1, 3]] - pad_h) / scale
+
+ # Scale landmarks
+ for i in range(5):
+ landmarks[:, i * 2] = (landmarks[:, i * 2] - pad_w) / scale
+ landmarks[:, i * 2 + 1] = (landmarks[:, i * 2 + 1] - pad_h) / scale
+
+ # Reshape landmarks to (N, 5, 2)
+ landmarks = landmarks.reshape(-1, 5, 2)
+
+ # Combine results
+ detections = np.concatenate([boxes, scores[:, None]], axis=1)
+
+ return detections, landmarks
+
+ def _xywh2xyxy(self, x: np.ndarray) -> np.ndarray:
+ """
+ Convert bounding box format from xywh to xyxy.
+
+ Args:
+ x (np.ndarray): Boxes in [x, y, w, h] format
+
+ Returns:
+ np.ndarray: Boxes in [x1, y1, x2, y2] format
+ """
+ y = np.copy(x)
+ y[..., 0] = x[..., 0] - x[..., 2] / 2 # x1
+ y[..., 1] = x[..., 1] - x[..., 3] / 2 # y1
+ y[..., 2] = x[..., 0] + x[..., 2] / 2 # x2
+ y[..., 3] = x[..., 1] + x[..., 3] / 2 # y2
+ return y
+
+ def detect(
+ self,
+ image: np.ndarray,
+ max_num: int = 0,
+ metric: Literal['default', 'max'] = 'max',
+ center_weight: float = 2.0,
+ ) -> List[Dict[str, Any]]:
+ """
+ Perform face detection on an input image and return bounding boxes and facial landmarks.
+
+ Args:
+ image (np.ndarray): Input image as a NumPy array of shape (H, W, C).
+ max_num (int): Maximum number of detections to return. Use 0 to return all detections. Defaults to 0.
+ metric (Literal["default", "max"]): Metric for ranking detections when `max_num` is limited.
+ - "default": Prioritize detections closer to the image center.
+ - "max": Prioritize detections with larger bounding box areas.
+ center_weight (float): Weight for penalizing detections farther from the image center
+ when using the "default" metric. Defaults to 2.0.
+
+ Returns:
+ List[Dict[str, Any]]: List of face detection dictionaries, each containing:
+ - 'bbox' (np.ndarray): Bounding box coordinates with shape (4,) as [x1, y1, x2, y2]
+ - 'confidence' (float): Detection confidence score (0.0 to 1.0)
+ - 'landmarks' (np.ndarray): 5-point facial landmarks with shape (5, 2)
+
+ Example:
+ >>> faces = detector.detect(image)
+ >>> for face in faces:
+ ... bbox = face['bbox'] # np.ndarray with shape (4,)
+ ... confidence = face['confidence'] # float
+ ... landmarks = face['landmarks'] # np.ndarray with shape (5, 2)
+ ... # Can pass landmarks directly to recognition
+ ... embedding = recognizer.get_normalized_embedding(image, landmarks)
+ """
+
+ original_height, original_width = image.shape[:2]
+
+ # Preprocess
+ image_tensor, scale, padding = self.preprocess(image)
+
+ # ONNXRuntime inference
+ outputs = self.inference(image_tensor)
+
+ # Postprocess
+ detections, landmarks = self.postprocess(outputs[0], scale, padding)
+
+ # Handle case when no faces are detected
+ if len(detections) == 0:
+ return []
+
+ if 0 < max_num < detections.shape[0]:
+ # Calculate area of detections
+ area = (detections[:, 2] - detections[:, 0]) * (detections[:, 3] - detections[:, 1])
+
+ # Calculate offsets from image center
+ center = (original_height // 2, original_width // 2)
+ offsets = np.vstack(
+ [
+ (detections[:, 0] + detections[:, 2]) / 2 - center[1],
+ (detections[:, 1] + detections[:, 3]) / 2 - center[0],
+ ]
+ )
+
+ # Calculate scores based on the chosen metric
+ offset_dist_squared = np.sum(np.power(offsets, 2.0), axis=0)
+ if metric == 'max':
+ values = area
+ else:
+ values = area - offset_dist_squared * center_weight
+
+ # Sort by scores and select top `max_num`
+ sorted_indices = np.argsort(values)[::-1][:max_num]
+ detections = detections[sorted_indices]
+ landmarks = landmarks[sorted_indices]
+
+ faces = []
+ for i in range(detections.shape[0]):
+ face_dict = {
+ 'bbox': detections[i, :4].astype(np.float32),
+ 'confidence': float(detections[i, 4]),
+ 'landmarks': landmarks[i].astype(np.float32),
+ }
+ faces.append(face_dict)
+
+ return faces