Mirror of https://github.com/yakhyo/uniface.git (synced 2025-12-30 09:02:25 +00:00)

Commit: feat: Add model name signature and several more updates

README.md
@@ -63,21 +63,47 @@ Explore the following example notebooks to learn how to use **UniFace** effectively
- [Face Alignment](examples/face_alignment.ipynb): Shows how to align faces using detected landmarks.
- [Age and Gender Detection](examples/age_gender.ipynb): Example for detecting age and gender from faces. (under development)

### Initialize the Model
### 🚀 Initialize the RetinaFace Model

To use the RetinaFace model for face detection, initialize it with either custom or default configuration parameters.

#### Full Initialization (with custom parameters)

```python
from uniface import RetinaFace
from uniface.constants import RetinaFaceWeights

# Initialize RetinaFace with custom configuration
uniface_inference = RetinaFace(
    model_name=RetinaFaceWeights.MNET_V2,  # Model name from enum
    conf_thresh=0.5,                       # Confidence threshold for detections
    pre_nms_topk=5000,                     # Number of top detections before NMS
    nms_thresh=0.4,                        # IoU threshold for NMS
    post_nms_topk=750,                     # Number of top detections after NMS
    dynamic_size=False,                    # Whether to allow arbitrary input sizes
    input_size=(640, 640)                  # Input image size (HxW)
)
```

#### Minimal Initialization (uses default parameters)

```python
from uniface import RetinaFace

# Initialize the RetinaFace model
uniface_inference = RetinaFace(
    model_name="retinaface_mnet_v2",  # Model name
    conf_thresh=0.5,                  # Confidence threshold
    pre_nms_topk=5000,                # Pre-NMS Top-K detections
    nms_thresh=0.4,                   # NMS IoU threshold
    post_nms_topk=750,                # Post-NMS Top-K detections
    dynamic_size=False,               # Arbitrary image size inference
    input_size=(640, 640)             # Pre-defined input image size
)
# Initialize with default settings
uniface_inference = RetinaFace()
```

**Default Parameters:**

```python
model_name = RetinaFaceWeights.MNET_V2
conf_thresh = 0.5
pre_nms_topk = 5000
nms_thresh = 0.4
post_nms_topk = 750
dynamic_size = False
input_size = (640, 640)
```

### Run Inference

@@ -170,9 +196,11 @@ cv2.destroyAllWindows()

```python
from typing import Tuple
from uniface import RetinaFace
from uniface.constants import RetinaFaceWeights

RetinaFace(
    model_name: str,
    model_name: RetinaFaceWeights,
    conf_thresh: float = 0.5,
    pre_nms_topk: int = 5000,
    nms_thresh: float = 0.4,
@@ -184,9 +212,8 @@ RetinaFace(

**Parameters**:

- `model_name` _(str)_: Name of the model to use. Supported models:
  - `retinaface_mnet025`, `retinaface_mnet050`, `retinaface_mnet_v1`, `retinaface_mnet_v2`
  - `retinaface_r18`, `retinaface_r34`
- `model_name` _(RetinaFaceWeights)_: Enum value for model to use. Supported values:
  - `MNET_025`, `MNET_050`, `MNET_V1`, `MNET_V2`, `RESNET18`, `RESNET34`
- `conf_thresh` _(float, default=0.5)_: Minimum confidence score for detections.
- `pre_nms_topk` _(int, default=5000)_: Max detections to keep before NMS.
- `nms_thresh` _(float, default=0.4)_: IoU threshold for Non-Maximum Suppression.
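For quick reference, here is a minimal end-to-end sketch combining the constructor above with the `detect` and `draw_detections` calls used by `scripts/run_inference.py` added in this commit; the image path is only a placeholder and 0.6 is just a visualization threshold:

```python
import cv2

from uniface import RetinaFace, draw_detections
from uniface.constants import RetinaFaceWeights

# Hedged usage sketch: detect faces in one image and draw the results.
detector = RetinaFace(model_name=RetinaFaceWeights.MNET_V2)

image = cv2.imread("assets/test.jpg")            # placeholder path; any BGR image works
boxes, landmarks = detector.detect(image)        # boxes and landmarks, as unpacked in run_inference.py
draw_detections(image, (boxes, landmarks), 0.6)  # third argument: visualization confidence threshold
cv2.imwrite("test_out.jpg", image)
```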
File diff suppressed because one or more lines are too long

@@ -30,14 +30,15 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"from PIL import Image\n",
"import IPython.display as display\n",
"from uniface import RetinaFace, draw_detections"
"from uniface import RetinaFace, draw_detections\n",
"from uniface.constants import RetinaFaceWeights"
]
},
{
@@ -56,27 +57,29 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-03-16 14:32:33,786 - INFO - Initializing RetinaFace with model=retinaface_mnet_v2, conf_thresh=0.5, nms_thresh=0.4, pre_nms_topk=5000, post_nms_topk=750, dynamic_size=False, input_size=(640, 640)\n",
"2025-03-16 14:32:33,830 - INFO - Verified model weights located at: C:\\Users\\yakhyo/.uniface/models\\retinaface_mnet_v2.onnx\n",
"2025-03-16 14:32:33,926 - INFO - Successfully initialized the model from C:\\Users\\yakhyo/.uniface/models\\retinaface_mnet_v2.onnx\n"
"2025-03-26 11:43:28,753 - INFO - Initializing RetinaFace with model=RetinaFaceWeights.MNET_V2, conf_thresh=0.5, nms_thresh=0.4, pre_nms_topk=5000, post_nms_topk=750, dynamic_size=False, input_size=(640, 640)\n",
"2025-03-26 11:43:28,753 - INFO - Downloading model 'RetinaFaceWeights.MNET_V2' from https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv2.onnx\n",
"2025-03-26 11:43:31,322 - INFO - Successfully downloaded 'RetinaFaceWeights.MNET_V2' to C:\\Users\\yakhyo\\.uniface\\models\\RetinaFaceWeights.MNET_V2.onnx\n",
"2025-03-26 11:43:31,334 - INFO - Verified model weights located at: C:\\Users\\yakhyo/.uniface/models\\RetinaFaceWeights.MNET_V2.onnx\n",
"2025-03-26 11:43:31,393 - INFO - Successfully initialized the model from C:\\Users\\yakhyo/.uniface/models\\RetinaFaceWeights.MNET_V2.onnx\n"
]
}
],
"source": [
"# Initialize the RetinaFace model\n",
"uniface_inference = RetinaFace(\n",
" model_name=\"retinaface_mnet_v2\", # Model name\n",
" conf_thresh=0.5, # Confidence threshold\n",
" pre_nms_topk=5000, # Pre-NMS Top-K detections\n",
" nms_thresh=0.4, # NMS IoU threshold\n",
" post_nms_topk=750 # Post-NMS Top-K detections,\n",
" model_name=RetinaFaceWeights.MNET_V2, # Model name\n",
" conf_thresh=0.5, # Confidence threshold\n",
" pre_nms_topk=5000, # Pre-NMS Top-K detections\n",
" nms_thresh=0.4, # NMS IoU threshold\n",
" post_nms_topk=750 # Post-NMS Top-K detections,\n",
")"
]
},
@@ -90,7 +93,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -101,7 +104,7 @@
"<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x624>"
]
},
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -122,7 +125,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -167,7 +170,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -209,7 +212,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -262,7 +265,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"display_name": "base",
"language": "python",
"name": "python3"
},
@@ -276,7 +279,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
"version": "3.12.2"
}
},
"nbformat": 4,
scripts/README.md (new file, 18 lines)
@@ -0,0 +1,18 @@
### `download_model.py`

# Download all models

```bash
python scripts/download_model.py
```

# Download just RESNET18

```bash
python scripts/download_model.py --model RESNET18
```

### `run_inference.py`
```bash
python scripts/run_inference.py --image assets/test.jpg --model MNET_V2 --iterations 10
```
scripts/download_model.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import argparse
from uniface.constants import RetinaFaceWeights
from uniface.model_store import verify_model_weights


def main():
    parser = argparse.ArgumentParser(description="Download and verify RetinaFace model weights.")
    parser.add_argument(
        "--model",
        type=str,
        choices=[m.name for m in RetinaFaceWeights],
        help="Model to download (e.g. MNET_V2). If not specified, all models will be downloaded.",
    )
    args = parser.parse_args()

    if args.model:
        weight = RetinaFaceWeights[args.model]
        print(f"📥 Downloading model: {weight.value}")
        verify_model_weights(weight.value)
    else:
        print("📥 Downloading all models...")
        for weight in RetinaFaceWeights:
            verify_model_weights(weight.value)

    print("✅ All requested weights are ready and verified.")


if __name__ == "__main__":
    main()
scripts/run_inference.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import os
import cv2
import time
import argparse
import numpy as np

from uniface import RetinaFace, draw_detections
from uniface.constants import RetinaFaceWeights


def run_inference(model, image_path, vis_threshold=0.6, save_dir="outputs"):
    """
    Run face detection on a single image.

    Args:
        model (RetinaFace): Initialized RetinaFace model.
        image_path (str): Path to input image.
        vis_threshold (float): Threshold for drawing detections.
        save_dir (str): Directory to save output image.
    """
    image = cv2.imread(image_path)
    if image is None:
        print(f"❌ Error: Failed to load image from '{image_path}'")
        return

    boxes, landmarks = model.detect(image)
    draw_detections(image, (boxes, landmarks), vis_threshold)

    os.makedirs(save_dir, exist_ok=True)
    output_path = os.path.join(save_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}_out.jpg")
    cv2.imwrite(output_path, image)
    print(f"✅ Output saved at: {output_path}")


def main():
    parser = argparse.ArgumentParser(description="Run RetinaFace inference on an image.")
    parser.add_argument("--image", type=str, required=True, help="Path to the input image")
    parser.add_argument("--model", type=str, default="MNET_V2", choices=[m.name for m in RetinaFaceWeights], help="Model variant to use")
    parser.add_argument("--threshold", type=float, default=0.6, help="Visualization confidence threshold")
    parser.add_argument("--iterations", type=int, default=1, help="Number of inference runs for benchmarking")
    parser.add_argument("--save_dir", type=str, default="outputs", help="Directory to save output images")

    args = parser.parse_args()

    model_name = RetinaFaceWeights[args.model]
    model = RetinaFace(model_name=model_name)

    avg_time = 0
    for i in range(args.iterations):
        start = time.time()
        run_inference(model, args.image, args.threshold, args.save_dir)
        elapsed = time.time() - start
        print(f"[{i + 1}/{args.iterations}] ⏱️ Inference time: {elapsed:.4f} seconds")
        avg_time += elapsed

    if args.iterations > 1:
        print(f"\n🔥 Average inference time over {args.iterations} runs: {avg_time / args.iterations:.4f} seconds")


if __name__ == "__main__":
    main()
test.py (deleted, 57 lines)
@@ -1,57 +0,0 @@
import os
import cv2
import numpy as np

from uniface import RetinaFace, draw_detections


def run_inference(image_path, save_image=False, vis_threshold=0.6):
    """
    Perform inference on an image, draw detections, and optionally save the output image.

    Args:
        image_path (str): Path to the input image.
        save_image (bool): Whether to save the output image with detections.
        vis_threshold (float): Confidence threshold for displaying detections.
    """
    # Load the image
    original_image = cv2.imread(image_path)
    if original_image is None:
        print(f"Error: Could not read image from {image_path}")
        return

    # Perform face detection
    boxes, landmarks = retinaface_inference.detect(original_image)

    # Draw detections on the image
    draw_detections(original_image, (boxes, landmarks), vis_threshold)

    # Save the output image if requested
    if save_image:
        im_name = os.path.splitext(os.path.basename(image_path))[0]
        save_name = f"{im_name}_out.jpg"
        cv2.imwrite(save_name, original_image)
        print(f"Image saved at '{save_name}'")


if __name__ == '__main__':
    import time

    # Initialize and run the ONNX inference
    retinaface_inference = RetinaFace(
        model_name="retinaface_mnet_v2",
        conf_thresh=0.5,
        pre_nms_topk=5000,
        nms_thresh=0.4,
        post_nms_topk=750,
    )

    img_path = "assets/test.jpg"
    avg = 0
    for _ in range(50):
        st = time.time()
        run_inference(img_path, save_image=True, vis_threshold=0.6)
        d = time.time() - st
        print(d)
        avg += d
    print("avg", avg / 50)
@@ -1,6 +1,7 @@
import pytest
import numpy as np
from uniface import RetinaFace
from uniface.constants import RetinaFaceWeights


@pytest.fixture
@@ -9,7 +10,7 @@ def retinaface_model():
    Fixture to initialize the RetinaFace model for testing.
    """
    return RetinaFace(
        model="retinaface_mnet_v2",
        model_name=RetinaFaceWeights.MNET_V2,
        conf_thresh=0.5,
        pre_nms_topk=5000,
        nms_thresh=0.4,
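For context, a hedged sketch of how the updated fixture might be exercised in a test; the blank input image and the assertion are illustrative only (the real test assertions are not shown in this extract), and the paired boxes/landmarks return is assumed from how `scripts/run_inference.py` unpacks `detect`:

```python
import numpy as np


def test_detect_returns_paired_outputs(retinaface_model):
    # Illustrative only: detect on a blank BGR image and check that boxes
    # and landmarks come back as paired sequences.
    image = np.zeros((640, 640, 3), dtype=np.uint8)
    boxes, landmarks = retinaface_model.detect(image)
    assert len(boxes) == len(landmarks)
```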
@@ -2,25 +2,35 @@
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo

from enum import Enum
from typing import Dict


MODEL_URLS: Dict[str, str] = {
    'retinaface_mnet025': 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv1_0.25.onnx',
    'retinaface_mnet050': 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv1_0.50.onnx',
    'retinaface_mnet_v1': 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv1.onnx',
    'retinaface_mnet_v2': 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv2.onnx',
    'retinaface_r18': 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_r18.onnx',
    'retinaface_r34': 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_r34.onnx'
class RetinaFaceWeights(str, Enum):
    MNET_025 = "retinaface_mnet025"
    MNET_050 = "retinaface_mnet050"
    MNET_V1 = "retinaface_mnet_v1"
    MNET_V2 = "retinaface_mnet_v2"
    RESNET18 = "retinaface_r18"
    RESNET34 = "retinaface_r34"


MODEL_URLS: Dict[RetinaFaceWeights, str] = {
    RetinaFaceWeights.MNET_025: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv1_0.25.onnx',
    RetinaFaceWeights.MNET_050: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv1_0.50.onnx',
    RetinaFaceWeights.MNET_V1: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv1.onnx',
    RetinaFaceWeights.MNET_V2: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_mv2.onnx',
    RetinaFaceWeights.RESNET18: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_r18.onnx',
    RetinaFaceWeights.RESNET34: 'https://github.com/yakhyo/uniface/releases/download/v0.1.2/retinaface_r34.onnx'
}

MODEL_SHA256: Dict[str, str] = {
    'retinaface_mnet025': 'b7a7acab55e104dce6f32cdfff929bd83946da5cd869b9e2e9bdffafd1b7e4a5',
    'retinaface_mnet050': 'd8977186f6037999af5b4113d42ba77a84a6ab0c996b17c713cc3d53b88bfc37',
    'retinaface_mnet_v1': '75c961aaf0aff03d13c074e9ec656e5510e174454dd4964a161aab4fe5f04153',
    'retinaface_mnet_v2': '3ca44c045651cabeed1193a1fae8946ad1f3a55da8fa74b341feab5a8319f757',
    'retinaface_r18': 'e8b5ddd7d2c3c8f7c942f9f10cec09d8e319f78f09725d3f709631de34fb649d',
    'retinaface_r34': 'bd0263dc2a465d32859555cb1741f2d98991eb0053696e8ee33fec583d30e630'
MODEL_SHA256: Dict[RetinaFaceWeights, str] = {
    RetinaFaceWeights.MNET_025: 'b7a7acab55e104dce6f32cdfff929bd83946da5cd869b9e2e9bdffafd1b7e4a5',
    RetinaFaceWeights.MNET_050: 'd8977186f6037999af5b4113d42ba77a84a6ab0c996b17c713cc3d53b88bfc37',
    RetinaFaceWeights.MNET_V1: '75c961aaf0aff03d13c074e9ec656e5510e174454dd4964a161aab4fe5f04153',
    RetinaFaceWeights.MNET_V2: '3ca44c045651cabeed1193a1fae8946ad1f3a55da8fa74b341feab5a8319f757',
    RetinaFaceWeights.RESNET18: 'e8b5ddd7d2c3c8f7c942f9f10cec09d8e319f78f09725d3f709631de34fb649d',
    RetinaFaceWeights.RESNET34: 'bd0263dc2a465d32859555cb1741f2d98991eb0053696e8ee33fec583d30e630'
}

CHUNK_SIZE = 8192
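A small sketch of consulting the new enum-keyed tables; it assumes `MODEL_URLS` and `MODEL_SHA256` remain importable from `uniface.constants` alongside the enum, as the `const.MODEL_SHA256` reference in the model_store hunk below suggests:

```python
from uniface.constants import RetinaFaceWeights, MODEL_URLS, MODEL_SHA256

weight = RetinaFaceWeights.MNET_V2
print(weight.value)          # "retinaface_mnet_v2" -- the enum is str-valued
print(MODEL_URLS[weight])    # release URL of the ONNX weights
print(MODEL_SHA256[weight])  # expected SHA-256 checksum of the download
```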
@@ -49,9 +49,11 @@ def verify_model_weights(model_name: str, root: str = '~/.uniface/models') -> str
        Logger.error(f"No URL found for model '{model_name}'")
        raise ValueError(f"No URL found for model '{model_name}'")

        Logger.info(f"Downloading '{model_name}' from {url}")
        Logger.info(f"Downloading model '{model_name}' from {url}")
        download_file(url, model_path)
        Logger.info(f"Successfully '{model_name}' downloaded to {model_path}")
        Logger.info(f"Successfully downloaded '{model_name}' to {os.path.normpath(model_path)}")
    else:
        Logger.info(f"Model '{model_name}' already exists at {os.path.normpath(model_path)}")

    expected_hash = const.MODEL_SHA256.get(model_name)
    if expected_hash and not verify_file_hash(model_path, expected_hash):
@@ -88,14 +90,7 @@ def verify_file_hash(file_path: str, expected_hash: str) -> bool:


if __name__ == "__main__":
    model_names = [
        'retinaface_mnet025',
        'retinaface_mnet050',
        'retinaface_mnet_v1',
        'retinaface_mnet_v2',
        'retinaface_r18',
        'retinaface_r34'
    ]
    model_names = [model.value for model in const.RetinaFaceWeights]

    # Download each model in the list
    for model_name in model_names:
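For reference, a hedged sketch of calling the download/verify helper directly, mirroring how `scripts/download_model.py` uses it; the RESNET18 choice is arbitrary, and the return value is the local weight path per the signature shown above:

```python
from uniface.constants import RetinaFaceWeights
from uniface.model_store import verify_model_weights

# Downloads the ONNX file into ~/.uniface/models if it is missing, verifies
# its SHA-256 against MODEL_SHA256, and returns the local path.
model_path = verify_model_weights(RetinaFaceWeights.RESNET18.value)
print(model_path)
```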
@@ -11,6 +11,7 @@ from typing import Tuple, List, Optional, Literal

from uniface.log import Logger
from uniface.model_store import verify_model_weights
from uniface.constants import RetinaFaceWeights
from uniface.common import (
    nms,
    resize_image,
@@ -25,7 +26,7 @@ class RetinaFace:
    A class for face detection using the RetinaFace model.

    Args:
        model_name (str): Name of the model.
        model_name (RetinaFaceWeights): Name of the model. Defaults to "retinaface_mnet_v2".
        conf_thresh (float, optional): Confidence threshold for detections. Defaults to 0.5.
        nms_thresh (float, optional): Non-maximum suppression (NMS) threshold. Defaults to 0.4.
        pre_nms_topk (int, optional): Maximum number of detections considered before applying NMS. Defaults to 5000.
@@ -50,7 +51,7 @@ class RetinaFace:

    def __init__(
        self,
        model_name: str,
        model_name: RetinaFaceWeights = RetinaFaceWeights.MNET_V2,
        conf_thresh: float = 0.5,
        nms_thresh: float = 0.4,
        pre_nms_topk: int = 5000,
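A short sketch of what the new default argument means in practice, based on the signature change above and the README examples in this commit:

```python
from uniface import RetinaFace
from uniface.constants import RetinaFaceWeights

# With the updated signature, model_name defaults to RetinaFaceWeights.MNET_V2,
# so these two constructions are equivalent.
detector_default = RetinaFace()
detector_explicit = RetinaFace(model_name=RetinaFaceWeights.MNET_V2)
```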