Compare commits: v3.0.0...feat/unifa

8 commits: cb81d2fcf8, f0bae6dd80, eec8f99850, 3682a2124f, 2ef6a1ebe8, 78a2dba7c7, 87e496d1f5, 5604ebf4f1
Removed binary files:

- BIN .github/logos/gaze_crop.png (vendored, 716 KiB)
- BIN .github/logos/gaze_org.png (vendored, 673 KiB)
- BIN .github/logos/logo_preview.jpg (vendored, 826 KiB)
- BIN .github/logos/logo_readme.png (vendored, 563 KiB)
- BIN .github/logos/logo_web.webp (vendored, 33 KiB)
Six additional binary image files were modified with sizes unchanged (427 KiB, 1.7 MiB, 1.8 MiB, 1.9 MiB, 872 KiB, 62 KiB).
.gitignore (vendored): 1 change

@@ -1,5 +1,6 @@
 tmp_*
 .vscode/
+*.onnx

 # Byte-compiled / optimized / DLL files
 __pycache__/
README.md: 27 changes

@@ -14,7 +14,7 @@
 </div>

 <div align="center">
-  <img src="https://raw.githubusercontent.com/yakhyo/uniface/main/.github/logos/new/uniface_rounded_q80.webp" width="90%" alt="UniFace - All-in-One Open-Source Face Analysis Library">
+  <img src="https://raw.githubusercontent.com/yakhyo/uniface/main/.github/logos/uniface_rounded_q80.webp" width="90%" alt="UniFace - All-in-One Open-Source Face Analysis Library">
 </div>

 ---

@@ -32,6 +32,7 @@
 - **Face Parsing** — BiSeNet semantic segmentation (19 classes), XSeg face masking
 - **Gaze Estimation** — Real-time gaze direction with MobileGaze
 - **Attribute Analysis** — Age, gender, race (FairFace), and emotion
+- **Vector Indexing** — FAISS-backed embedding store for fast multi-identity search
 - **Anti-Spoofing** — Face liveness detection with MiniFASNet
 - **Face Anonymization** — 5 blur methods for privacy protection
 - **Hardware Acceleration** — ARM64 (Apple Silicon), CUDA (NVIDIA), CPU

@@ -59,6 +60,12 @@ git clone https://github.com/yakhyo/uniface.git
 cd uniface && pip install -e .
 ```

+**FAISS vector indexing**
+
+```bash
+pip install faiss-cpu  # or faiss-gpu for CUDA
+```
+
 **Optional dependencies**

 - Emotion model uses TorchScript and requires `torch`:
   `pip install torch` (choose the correct build for your OS/CUDA)

@@ -165,6 +172,23 @@ Full documentation: https://yakhyo.github.io/uniface/
 | [API Reference](https://yakhyo.github.io/uniface/modules/detection/) | Detailed module documentation |
 | [Tutorials](https://yakhyo.github.io/uniface/recipes/image-pipeline/) | Step-by-step workflow examples |
 | [Guides](https://yakhyo.github.io/uniface/concepts/overview/) | Architecture and design principles |
+| [Datasets](https://yakhyo.github.io/uniface/datasets/) | Training data and evaluation benchmarks |
+
+---
+
+## Datasets
+
+| Task | Training Dataset | Models |
+|------|-----------------|--------|
+| Detection | WIDER FACE | RetinaFace, SCRFD, YOLOv5-Face, YOLOv8-Face |
+| Recognition | MS1MV2 | MobileFace, SphereFace |
+| Recognition | WebFace600K | ArcFace |
+| Recognition | WebFace4M / 12M | AdaFace |
+| Gaze | Gaze360 | MobileGaze |
+| Parsing | CelebAMask-HQ | BiSeNet |
+| Attributes | CelebA, FairFace, AffectNet | AgeGender, FairFace, Emotion |
+
+> See [Datasets documentation](https://yakhyo.github.io/uniface/datasets/) for download links, benchmarks, and details.

 ---

@@ -181,6 +205,7 @@ Full documentation: https://yakhyo.github.io/uniface/
 | [07_face_anonymization.ipynb](examples/07_face_anonymization.ipynb) | [Open in Colab](https://colab.research.google.com/github/yakhyo/uniface/blob/main/examples/07_face_anonymization.ipynb) | Privacy-preserving blur |
 | [08_gaze_estimation.ipynb](examples/08_gaze_estimation.ipynb) | [Open in Colab](https://colab.research.google.com/github/yakhyo/uniface/blob/main/examples/08_gaze_estimation.ipynb) | Gaze direction estimation |
 | [09_face_segmentation.ipynb](examples/09_face_segmentation.ipynb) | [Open in Colab](https://colab.research.google.com/github/yakhyo/uniface/blob/main/examples/09_face_segmentation.ipynb) | Face segmentation with XSeg |
+| [10_face_vector_store.ipynb](examples/10_face_vector_store.ipynb) | [Open in Colab](https://colab.research.google.com/github/yakhyo/uniface/blob/main/examples/10_face_vector_store.ipynb) | FAISS-backed face database |

 ---
BIN assets/einstein/img_0.png (new file, 99 KiB)
@@ -32,6 +32,10 @@ graph TB
         TRK[BYTETracker]
     end

+    subgraph Indexing
+        IDX[FAISS Vector Store]
+    end
+
     subgraph Output
         FACE[Face Objects]
     end

@@ -45,6 +49,7 @@ graph TB
     DET --> SPOOF
     DET --> PRIV
     DET --> TRK
+    REC --> IDX
     REC --> FACE
     LMK --> FACE
     ATTR --> FACE

@@ -57,12 +62,14 @@ graph TB

 ### 1. ONNX-First

-All models use ONNX Runtime for inference:
+UniFace runs inference primarily via ONNX Runtime for core components:

 - **Cross-platform**: Same models work on macOS, Linux, Windows
 - **Hardware acceleration**: Automatic selection of optimal provider
 - **Production-ready**: No Python-only dependencies for inference

+Some optional components (e.g., emotion TorchScript, torchvision NMS) require PyTorch.
+
 ### 2. Minimal Dependencies

 Core dependencies are kept minimal:

@@ -114,6 +121,7 @@ uniface/
 ├── gaze/           # Gaze estimation
 ├── spoofing/       # Anti-spoofing
 ├── privacy/        # Face anonymization
+├── indexing/       # Vector indexing (FAISS)
 ├── types.py        # Dataclasses (Face, GazeResult, etc.)
 ├── constants.py    # Model weights and URLs
 ├── model_store.py  # Model download and caching
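Read as a dataflow, the updated graph wires recognition embeddings into the new FAISS store before results reach the output stage. A minimal sketch of that path, using only APIs that appear elsewhere in this diff:

```python
# Dataflow sketch matching the graph: DET -> REC -> IDX
import cv2
from uniface.detection import RetinaFace   # DET
from uniface.recognition import ArcFace    # REC
from uniface.indexing import FAISS         # IDX

detector = RetinaFace()
recognizer = ArcFace()
store = FAISS(embedding_size=512, db_path="./vector_index")

image = cv2.imread("photo.jpg")
faces = detector.detect(image)  # Face objects carry bbox + landmarks
for face in faces:
    # REC: landmark-aligned, unit-length embedding
    emb = recognizer.get_normalized_embedding(image, face.landmarks)
    # REC --> IDX: index the embedding with arbitrary metadata
    store.add(emb, {"source": "photo.jpg"})

store.save()
```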
docs/datasets.md (new file, 324 lines)

@@ -0,0 +1,324 @@
# Datasets

Overview of all training datasets and evaluation benchmarks used by UniFace models.

---

## Quick Reference

| Task | Dataset | Scale | Models |
| ----------- | ------------------------------------------------ | ---------------------- | ------------------------------------------- |
| Detection | [WIDER FACE](#wider-face) | 32K images | RetinaFace, SCRFD, YOLOv5-Face, YOLOv8-Face |
| Recognition | [MS1MV2](#ms1mv2) | 5.8M images, 85.7K IDs | MobileFace, SphereFace |
| Recognition | [WebFace600K](#webface600k) | 600K images | ArcFace |
| Recognition | [WebFace4M / WebFace12M](#webface4m--webface12m) | 4M / 12M images | AdaFace |
| Gaze | [Gaze360](#gaze360) | 238 subjects | MobileGaze |
| Parsing | [CelebAMask-HQ](#celebamask-hq) | 30K images | BiSeNet |
| Attributes | [CelebA](#celeba) | 200K images | AgeGender |
| Attributes | [FairFace](#fairface) | Balanced demographics | FairFace |
| Attributes | [AffectNet](#affectnet) | Emotion labels | Emotion |

---

## Training Datasets

### Face Detection

#### WIDER FACE

Large-scale face detection benchmark with images across 61 event categories. Contains faces with a high degree of variability in scale, pose, occlusion, expression, and illumination.

| Property | Value |
| -------- | ------------------------------------------- |
| Images | ~32,000 (train/val/test split) |
| Faces | ~394,000 annotated |
| Subsets | Easy, Medium, Hard |
| Used by | RetinaFace, SCRFD, YOLOv5-Face, YOLOv8-Face |

!!! info "Download & References"
    **Paper**: [WIDER FACE: A Face Detection Benchmark](https://arxiv.org/abs/1511.06523)

    **Download**: [http://shuoyang1213.me/WIDERFACE/](http://shuoyang1213.me/WIDERFACE/)

---

### Face Recognition

#### MS1MV2

Refined version of the MS-Celeb-1M dataset, cleaned by InsightFace. Widely used for training face recognition models.

| Property | Value |
| ---------- | ------------------------------ |
| Identities | 85.7K |
| Images | 5.8M |
| Format | Aligned and cropped to 112x112 |
| Used by | MobileFace, SphereFace |

!!! info "Download"
    **Kaggle (aligned 112x112)**: [ms1m-arcface-dataset](https://www.kaggle.com/datasets/yakhyokhuja/ms1m-arcface-dataset) (from InsightFace)

    **Training code**: [yakhyo/face-recognition](https://github.com/yakhyo/face-recognition)

---

#### WebFace600K

Medium-scale face recognition dataset from the WebFace series.

| Property | Value |
| -------- | ------- |
| Images | ~600K |
| Used by | ArcFace |

!!! info "Source"
    **Origin**: [InsightFace](https://github.com/deepinsight/insightface)

    **Paper**: [ArcFace: Additive Angular Margin Loss for Deep Face Recognition](https://arxiv.org/abs/1801.07698)

---

#### WebFace4M / WebFace12M

Large-scale face recognition datasets from the WebFace260M collection. Used for training AdaFace models with adaptive quality-aware margin.

| Property | WebFace4M | WebFace12M |
| -------- | ------------- | -------------- |
| Images | ~4M | ~12M |
| Used by | AdaFace IR_18 | AdaFace IR_101 |

!!! info "Source"
    **Paper**: [AdaFace: Quality Adaptive Margin for Face Recognition](https://arxiv.org/abs/2204.00964)

    **Original code**: [mk-minchul/AdaFace](https://github.com/mk-minchul/AdaFace)

---

#### CASIA-WebFace

Smaller-scale face recognition dataset suitable for academic research and lighter training runs.

| Property | Value |
| ---------- | ------------------------------ |
| Identities | 10.6K |
| Images | 491K |
| Format | Aligned and cropped to 112x112 |
| Used by | Alternative training set |

!!! info "Download"
    **Kaggle (aligned 112x112)**: [webface-112x112](https://www.kaggle.com/datasets/yakhyokhuja/webface-112x112) (from OpenSphere)

---

#### VGGFace2

Large-scale dataset with wide variations in pose, age, illumination, ethnicity, and profession.

| Property | Value |
| ---------- | ------------------------------ |
| Identities | 8.6K |
| Images | 3.1M |
| Format | Aligned and cropped to 112x112 |
| Used by | Alternative training set |

!!! info "Download"
    **Kaggle (aligned 112x112)**: [vggface2-112x112](https://www.kaggle.com/datasets/yakhyokhuja/vggface2-112x112) (from OpenSphere)

---

### Gaze Estimation

#### Gaze360

Large-scale gaze estimation dataset collected in indoor and outdoor environments with diverse head poses and wide gaze ranges (up to 360 degrees).

| Property | Value |
| ----------- | --------------------- |
| Subjects | 238 |
| Environment | Indoor and outdoor |
| Used by | All MobileGaze models |

!!! info "Download & Preprocessing"
    **Download**: [gaze360.csail.mit.edu/download.php](https://gaze360.csail.mit.edu/download.php)

    **Preprocessing**: [GazeHub - Gaze360](https://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/#gaze360)

!!! note "UniFace Models"
    All MobileGaze models shipped with UniFace are trained exclusively on Gaze360 for 200 epochs.

**Dataset structure:**

```
data/
└── Gaze360/
    ├── Image/
    └── Label/
```

---

#### MPIIFaceGaze

Dataset for appearance-based gaze estimation from laptop webcam images of participants during everyday laptop usage. Supported by the gaze estimation training code but not used for the UniFace pretrained weights.

| Property | Value |
| ----------- | ---------------------------------------- |
| Subjects | 15 |
| Environment | Everyday laptop usage |
| Used by | Supported (not used for UniFace weights) |

!!! info "Download & Preprocessing"
    **Download**: [MPIIFaceGaze download page](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/its-written-all-over-your-face-full-face-appearance-based-gaze-estimation)

    **Preprocessing**: [GazeHub - MPIIFaceGaze](https://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/#mpiifacegaze)

**Dataset structure:**

```
data/
└── MPIIFaceGaze/
    ├── Image/
    └── Label/
```

---

### Face Parsing

#### CelebAMask-HQ

High-quality face parsing dataset with pixel-level annotations for 19 facial component classes.

| Property | Value |
| ---------- | ---------------------------- |
| Images | 30,000 |
| Classes | 19 facial components |
| Resolution | High quality |
| Used by | BiSeNet (ResNet18, ResNet34) |

!!! info "Source"
    **GitHub**: [switchablenorms/CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ)

    **Training code**: [yakhyo/face-parsing](https://github.com/yakhyo/face-parsing)

**Dataset structure:**

```
dataset/
├── images/  # Input face images
│   ├── image1.jpg
│   └── ...
└── labels/  # Segmentation masks
    ├── image1.png
    └── ...
```

---

### Attribute Analysis

#### CelebA

Large-scale face attributes dataset widely used for training age and gender prediction models.

| Property | Value |
| ---------- | -------------------- |
| Images | ~200K |
| Attributes | 40 binary attributes |
| Used by | AgeGender |

!!! info "Reference"
    **Paper**: [Deep Learning Face Attributes in the Wild](https://arxiv.org/abs/1411.7766)

---

#### FairFace

Face attribute dataset designed for balanced representation across race, gender, and age groups. Provides more equitable predictions compared to imbalanced datasets.

| Property | Value |
| ---------- | ----------------------------------- |
| Attributes | Race (7), Gender (2), Age Group (9) |
| Used by | FairFace |
| License | CC BY 4.0 |

!!! info "Reference"
    **Paper**: [FairFace: Face Attribute Dataset for Balanced Race, Gender, and Age](https://arxiv.org/abs/1908.04913)

    **ONNX inference**: [yakhyo/fairface-onnx](https://github.com/yakhyo/fairface-onnx)

---

#### AffectNet

Large-scale facial expression dataset for emotion recognition training.

| Property | Value |
| -------- | ----------------------------------------------------------------------- |
| Classes | 7 or 8 (Neutral, Happy, Sad, Surprise, Fear, Disgust, Angry + Contempt) |
| Used by | Emotion (AFFECNET7, AFFECNET8) |

!!! info "Reference"
    **Paper**: [AffectNet: A Database for Facial Expression, Valence, and Arousal Computing in the Wild](https://ieeexplore.ieee.org/document/8013713)

---

## Evaluation Benchmarks

### Face Detection

#### WIDER FACE Validation Set

The standard benchmark for face detection models. Results are reported across three difficulty subsets.

| Subset | Criteria |
| ------ | --------------------------------------------- |
| Easy | Large, clear, unoccluded faces |
| Medium | Moderate scale and occlusion |
| Hard | Small, heavily occluded, or challenging faces |

See [Model Zoo - Detection](models.md#face-detection-models) for per-model accuracy on each subset.

---

### Face Recognition

Recognition models are evaluated across multiple benchmarks. Aligned 112x112 validation datasets are available as a single download.

!!! info "Download"
    **Kaggle**: [agedb-30-calfw-cplfw-lfw-aligned-112x112](https://www.kaggle.com/datasets/yakhyokhuja/agedb-30-calfw-cplfw-lfw-aligned-112x112)

| Benchmark | Description | Used by |
| ------------ | ----------------------------------------------------------------- | ------------------------------- |
| **LFW** | Labeled Faces in the Wild - standard face verification benchmark | ArcFace, MobileFace, SphereFace |
| **CALFW** | Cross-Age LFW - face verification across age gaps | MobileFace, SphereFace |
| **CPLFW** | Cross-Pose LFW - face verification across pose variations | MobileFace, SphereFace |
| **AgeDB-30** | Age database with 30-year age gaps | ArcFace, MobileFace, SphereFace |
| **CFP-FP** | Celebrities in Frontal-Profile - frontal vs. profile verification | ArcFace |
| **IJB-B** | IARPA Janus Benchmark B - TAR@FAR=0.01% | AdaFace |
| **IJB-C** | IARPA Janus Benchmark C - TAR@FAR=1e-4 | AdaFace, ArcFace |

See [Model Zoo - Recognition](models.md#face-recognition-models) for per-model accuracy on each benchmark.
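For reference, the TAR@FAR numbers in the IJB-B/IJB-C rows above follow the standard verification-metric definition (general background, not UniFace-specific). With $s$ the similarity score of a pair and $t$ a decision threshold:

$$\mathrm{TAR}(t) = \frac{\#\{\text{genuine pairs with } s \ge t\}}{\#\{\text{genuine pairs}\}}, \qquad \mathrm{FAR}(t) = \frac{\#\{\text{impostor pairs with } s \ge t\}}{\#\{\text{impostor pairs}\}}$$

"TAR@FAR=1e-4" then reports $\mathrm{TAR}(t^\*)$ at the threshold $t^\*$ where $\mathrm{FAR}(t^\*) = 10^{-4}$ (and 0.01% is the same operating point written as a percentage).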
---

### Gaze Estimation

| Benchmark | Metric | Description |
| -------------------- | ------------- | -------------------------------------------- |
| **Gaze360 test set** | MAE (degrees) | Mean Absolute Error in gaze angle prediction |

See [Model Zoo - Gaze](models.md#gaze-estimation-models) for per-model MAE scores.

---

## Training Repositories

For training your own models or reproducing results, see the following repositories:

| Task | Repository | Datasets Supported |
| ----------- | ------------------------------------------------------------------------- | ------------------------------- |
| Detection | [yakhyo/retinaface-pytorch](https://github.com/yakhyo/retinaface-pytorch) | WIDER FACE |
| Recognition | [yakhyo/face-recognition](https://github.com/yakhyo/face-recognition) | MS1MV2, CASIA-WebFace, VGGFace2 |
| Gaze | [yakhyo/gaze-estimation](https://github.com/yakhyo/gaze-estimation) | Gaze360, MPIIFaceGaze |
| Parsing | [yakhyo/face-parsing](https://github.com/yakhyo/face-parsing) | CelebAMask-HQ |
@@ -20,7 +20,7 @@ template: home.html
 [Kaggle](https://www.kaggle.com/yakhyokhuja/code)
 [Discord](https://discord.gg/wdzrjr7R5j)

-<!-- <img src="https://raw.githubusercontent.com/yakhyo/uniface/main/.github/logos/new/uniface_rounded_q80.webp" alt="UniFace - All-in-One Open-Source Face Analysis Library" style="max-width: 70%; margin: 1rem 0;"> -->
+<!-- <img src="https://raw.githubusercontent.com/yakhyo/uniface/main/.github/logos/uniface_rounded_q80.webp" alt="UniFace - All-in-One Open-Source Face Analysis Library" style="max-width: 70%; margin: 1rem 0;"> -->

 [Get Started](quickstart.md){ .md-button .md-button--primary }
 [View on GitHub](https://github.com/yakhyo/uniface){ .md-button }
@@ -74,31 +74,35 @@ Face liveness detection with MiniFASNet to prevent fraud.
 Face anonymization with 5 blur methods for privacy protection.
 </div>

+<div class="feature-card" markdown>
+### :material-database-search: Vector Indexing
+FAISS-backed embedding store for fast multi-identity face search.
+</div>
+
 </div>

 ---

 ## Installation

-=== "Standard"
+UniFace runs inference primarily via **ONNX Runtime**; some optional components (e.g., emotion TorchScript, torchvision NMS) require **PyTorch**.
+
+**Standard**

 ```bash
 pip install uniface
 ```

-=== "GPU (CUDA)"
+**GPU (CUDA)**

 ```bash
 pip install uniface[gpu]
 ```

-=== "From Source"
+**From Source**

 ```bash
 git clone https://github.com/yakhyo/uniface.git
 cd uniface
 pip install -e .
 ```

 ---
@@ -55,11 +55,10 @@ pip install uniface[gpu]
 **Requirements:**

-- CUDA 11.x or 12.x
-- cuDNN 8.x
+- `uniface[gpu]` automatically installs `onnxruntime-gpu`. Requirements depend on the ORT version and execution provider.

 !!! info "CUDA Compatibility"
-    See [ONNX Runtime GPU requirements](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html) for detailed compatibility matrix.
+    See the [ONNX Runtime GPU compatibility matrix](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html) for matching CUDA and cuDNN versions.

 Verify GPU installation:
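The verification snippet itself is collapsed in this view; judging by the hunk context below (`ort.get_available_providers()`), a minimal check would look like this:

```python
# Minimal sketch: confirm onnxruntime-gpu exposes the CUDA provider.
import onnxruntime as ort

providers = ort.get_available_providers()
print("Available providers:", providers)

# "CUDAExecutionProvider" should be listed when the GPU build is installed correctly.
assert "CUDAExecutionProvider" in providers, "CUDA provider not available"
```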
@@ -71,6 +70,19 @@ print("Available providers:", ort.get_available_providers())
 ---

+### FAISS Vector Indexing
+
+For fast multi-identity face search using a FAISS index:
+
+```bash
+pip install faiss-cpu   # CPU
+pip install faiss-gpu   # NVIDIA GPU (CUDA)
+```
+
+See the [Indexing module](modules/indexing.md) for usage.
+
+---
+
 ### CPU-Only (All Platforms)

 ```bash
@@ -107,12 +119,20 @@ UniFace has minimal dependencies:
 | Package | Purpose |
 |---------|---------|
 | `numpy` | Array operations |
 | `opencv-python` | Image processing |
 | `onnx` | ONNX model format support |
 | `onnxruntime` | Model inference |
 | `scikit-image` | Geometric transforms |
 | `requests` | Model download |
 | `tqdm` | Progress bars |

+**Optional:**
+
+| Package | Install extra | Purpose |
+|---------|---------------|---------|
+| `faiss-cpu` / `faiss-gpu` | `pip install faiss-cpu` | FAISS vector indexing |
+| `onnxruntime-gpu` | `uniface[gpu]` | CUDA acceleration |
+| `torch` | `pip install torch` | Emotion model uses TorchScript |
+| `torchvision` | `pip install torchvision` | Faster NMS for YOLO detectors |

 ---

 ## Verify Installation
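Optional extras are only needed when the corresponding feature is used. A generic guard pattern for surfacing an actionable error (illustrative only, not UniFace's actual internals):

```python
# Illustrative pattern: fail with a helpful message when an optional extra is missing.
def require_faiss():
    try:
        import faiss  # provided by faiss-cpu or faiss-gpu
    except ImportError as e:
        raise ImportError(
            "FAISS vector indexing requires the optional dependency 'faiss'. "
            "Install it with: pip install faiss-cpu  (or faiss-gpu for CUDA)"
        ) from e
    return faiss
```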
@@ -8,7 +8,7 @@ Complete guide to all available models and their performance characteristics.

 ### RetinaFace Family

-RetinaFace models are trained on the WIDER FACE dataset.
+RetinaFace models are trained on the [WIDER FACE](datasets.md#wider-face) dataset.

 | Model Name | Params | Size | Easy | Medium | Hard |
 | -------------- | ------ | ----- | ------ | ------ | ------ |

@@ -28,7 +28,7 @@ RetinaFace models are trained on the WIDER FACE dataset.

 ### SCRFD Family

-SCRFD (Sample and Computation Redistribution for Efficient Face Detection) models trained on WIDER FACE dataset.
+SCRFD (Sample and Computation Redistribution for Efficient Face Detection) models trained on [WIDER FACE](datasets.md#wider-face) dataset.

 | Model Name | Params | Size | Easy | Medium | Hard |
 | ---------------- | ------ | ----- | ------ | ------ | ------ |

@@ -44,7 +44,7 @@ SCRFD (Sample and Computation Redistribution for Efficient Face Detection) model

 ### YOLOv5-Face Family

-YOLOv5-Face models provide detection with 5-point facial landmarks, trained on WIDER FACE dataset.
+YOLOv5-Face models provide detection with 5-point facial landmarks, trained on [WIDER FACE](datasets.md#wider-face) dataset.

 | Model Name | Size | Easy | Medium | Hard |
 | -------------- | ---- | ------ | ------ | ------ |

@@ -93,7 +93,7 @@ Face recognition using adaptive margin based on image quality.
 | `IR_101` | IR-101 | WebFace12M | 249 MB | - | 97.66% |

 !!! info "Training Data & Accuracy"
-    **Dataset**: WebFace4M (4M images) / WebFace12M (12M images)
+    **Dataset**: [WebFace4M / WebFace12M](datasets.md#webface4m--webface12m) (4M / 12M images)

     **Accuracy**: IJB-B and IJB-C benchmarks, TAR@FAR=0.01%

@@ -113,7 +113,7 @@ Face recognition using additive angular margin loss.
 | `RESNET` | ResNet50 | 43.6M | 166MB | 99.83% | 99.33% | 98.23% | 97.25% |

 !!! info "Training Data"
-    **Dataset**: Trained on WebFace600K (600K images)
+    **Dataset**: Trained on [WebFace600K](datasets.md#webface600k) (600K images)

     **Accuracy**: IJB-C accuracy reported as TAR@FAR=1e-4

@@ -131,7 +131,7 @@ Lightweight face recognition models with MobileNet backbones.
 | `MNET_V3_LARGE` | MobileNetV3-L | 3.52M | 10MB | 99.53% | 94.56% | 86.79% | 95.13% |

 !!! info "Training Data"
-    **Dataset**: Trained on MS1M-V2 (5.8M images, 85K identities)
+    **Dataset**: Trained on [MS1MV2](datasets.md#ms1mv2) (5.8M images, 85K identities)

     **Accuracy**: Evaluated on LFW, CALFW, CPLFW, and AgeDB-30 benchmarks

@@ -147,7 +147,7 @@ Face recognition using angular softmax loss.
 | `SPHERE36` | Sphere36 | 34.6M | 92MB | 99.72% | 95.64% | 89.92% | 96.83% |

 !!! info "Training Data"
-    **Dataset**: Trained on MS1M-V2 (5.8M images, 85K identities)
+    **Dataset**: Trained on [MS1MV2](datasets.md#ms1mv2) (5.8M images, 85K identities)

     **Accuracy**: Evaluated on LFW, CALFW, CPLFW, and AgeDB-30 benchmarks

@@ -187,7 +187,7 @@ Facial landmark localization model.
 | `AgeGender` | Age, Gender | 2.1M | 8MB |

 !!! info "Training Data"
-    **Dataset**: Trained on CelebA
+    **Dataset**: Trained on [CelebA](datasets.md#celeba)

 !!! warning "Accuracy Note"
     Accuracy varies by demographic and image quality. Test on your specific use case.

@@ -201,7 +201,7 @@ Facial landmark localization model.
 | `FairFace` | Race, Gender, Age Group | - | 44MB |

 !!! info "Training Data"
-    **Dataset**: Trained on FairFace dataset with balanced demographics
+    **Dataset**: Trained on [FairFace](datasets.md#fairface) dataset with balanced demographics

 !!! tip "Equitable Predictions"
     FairFace provides more equitable predictions across different racial and gender groups.

@@ -224,7 +224,7 @@ Facial landmark localization model.
     **Classes (8)**: Above + Contempt

 !!! info "Training Data"
-    **Dataset**: Trained on AffectNet
+    **Dataset**: Trained on [AffectNet](datasets.md#affectnet)

 !!! note "Accuracy Note"
     Emotion detection accuracy depends heavily on facial expression clarity and cultural context.

@@ -235,7 +235,7 @@ Facial landmark localization model.

 ### MobileGaze Family

-Gaze direction prediction models trained on Gaze360 dataset. Returns pitch (vertical) and yaw (horizontal) angles in radians.
+Gaze direction prediction models trained on [Gaze360](datasets.md#gaze360) dataset. Returns pitch (vertical) and yaw (horizontal) angles in radians.

 | Model Name | Params | Size | MAE* |
 | -------------- | ------ | ------- | ----- |

@@ -248,7 +248,7 @@ Gaze direction prediction models trained on Gaze360 dataset. Returns pitch (vert
 *MAE (Mean Absolute Error) in degrees on Gaze360 test set - lower is better

 !!! info "Training Data"
-    **Dataset**: Trained on Gaze360 (indoor/outdoor scenes with diverse head poses)
+    **Dataset**: Trained on [Gaze360](datasets.md#gaze360) (indoor/outdoor scenes with diverse head poses)

     **Training**: 200 epochs with classification-based approach (binned angles)
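MobileGaze returns pitch and yaw in radians, while the MAE table and the notebook logs later in this diff report degrees. A small conversion sketch, assuming the gaze result exposes `pitch` and `yaw` attributes (consistent with the `GazeResult` dataclass named in the module layout):

```python
# Convert MobileGaze pitch/yaw from radians to degrees for display.
# Assumes a result object with .pitch and .yaw attributes in radians.
import math

def gaze_to_degrees(result) -> tuple[float, float]:
    pitch_deg = math.degrees(result.pitch)  # vertical angle
    yaw_deg = math.degrees(result.yaw)      # horizontal angle
    return pitch_deg, yaw_deg
```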
@@ -269,7 +269,7 @@ BiSeNet (Bilateral Segmentation Network) models for semantic face parsing. Segme
 | `RESNET34` | 24.1M | 89.2 MB | 19 |

 !!! info "Training Data"
-    **Dataset**: Trained on CelebAMask-HQ
+    **Dataset**: Trained on [CelebAMask-HQ](datasets.md#celebamask-hq)

     **Architecture**: BiSeNet with ResNet backbone
docs/modules/indexing.md (new file, 172 lines)

@@ -0,0 +1,172 @@
# Indexing

FAISS-backed vector store for fast similarity search over embeddings.

!!! info "Optional dependency"
    ```bash
    pip install faiss-cpu
    ```

---

## FAISS

```python
from uniface.indexing import FAISS
```

A thin wrapper around a FAISS `IndexFlatIP` (inner-product) index. Vectors
**must** be L2-normalised before adding so that inner product equals cosine
similarity. The store does not normalise internally.

Each vector is paired with a metadata `dict` that can carry any
JSON-serialisable payload (person ID, name, source path, etc.).
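Since the store does not normalise internally, callers must unit-scale vectors first. A plain NumPy sketch (the `l2_normalize` helper here is illustrative, not part of UniFace; judging by its name, the `get_normalized_embedding(...)` call used throughout this diff should already return unit-length vectors):

```python
import numpy as np

def l2_normalize(embedding: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale the vector to unit length so inner product == cosine similarity."""
    norm = np.linalg.norm(embedding)
    return embedding / max(norm, eps)

# After normalisation, IndexFlatIP's inner product is exactly cosine similarity:
a = l2_normalize(np.random.rand(512).astype(np.float32))
b = l2_normalize(np.random.rand(512).astype(np.float32))
cosine = float(a @ b)  # bounded in [-1, 1]
```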
### Constructor

```python
store = FAISS(embedding_size=512, db_path="./vector_index")
```

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `embedding_size` | `int` | `512` | Dimension of embedding vectors |
| `db_path` | `str` | `"./vector_index"` | Directory for persisting index and metadata |

---

### Methods

#### `add(embedding, metadata)`

Add a single embedding with associated metadata.

```python
store.add(embedding, {"person_id": "alice", "source": "photo.jpg"})
```

| Parameter | Type | Description |
|-----------|------|-------------|
| `embedding` | `np.ndarray` | L2-normalised embedding vector |
| `metadata` | `dict[str, Any]` | Arbitrary JSON-serialisable key-value pairs |

---

#### `search(embedding, threshold=0.4)`

Find the closest match for a query embedding.

```python
result, similarity = store.search(query_embedding, threshold=0.4)
if result:
    print(result["person_id"], similarity)
```

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `embedding` | `np.ndarray` | — | L2-normalised query vector |
| `threshold` | `float` | `0.4` | Minimum cosine similarity to accept a match |

**Returns:** `(metadata, similarity)` if a match is found, or `(None, similarity)` when below threshold or the index is empty.

---

#### `remove(key, value)`

Remove all entries where `metadata[key] == value` and rebuild the index.

```python
removed = store.remove("person_id", "bob")
print(f"Removed {removed} entries")
```

| Parameter | Type | Description |
|-----------|------|-------------|
| `key` | `str` | Metadata key to match |
| `value` | `Any` | Value to match |

**Returns:** Number of entries removed.

---

#### `save()`

Persist the FAISS index and metadata to disk.

```python
store.save()
```

Writes two files to `db_path`:

- `faiss_index.bin` — binary FAISS index
- `metadata.json` — JSON array of metadata dicts

---

#### `load()`

Load a previously saved index and metadata.

```python
store = FAISS(db_path="./vector_index")
loaded = store.load()  # True if files exist
```

**Returns:** `True` if loaded successfully, `False` if files are missing.

**Raises:** `RuntimeError` if files exist but cannot be read.

---

### Properties

| Property | Type | Description |
|----------|------|-------------|
| `size` | `int` | Number of vectors in the index |
| `len(store)` | `int` | Same as `size` |

---

## Example: End-to-End

```python
import cv2
from uniface.detection import RetinaFace
from uniface.recognition import ArcFace
from uniface.indexing import FAISS

detector = RetinaFace()
recognizer = ArcFace()

# Build
store = FAISS(db_path="./my_index")

image = cv2.imread("alice.jpg")
faces = detector.detect(image)
embedding = recognizer.get_normalized_embedding(image, faces[0].landmarks)
store.add(embedding, {"person_id": "alice"})
store.save()

# Search
store2 = FAISS(db_path="./my_index")
store2.load()

query = cv2.imread("unknown.jpg")
faces = detector.detect(query)
emb = recognizer.get_normalized_embedding(query, faces[0].landmarks)

result, sim = store2.search(emb)
if result:
    print(f"Matched: {result['person_id']} (similarity: {sim:.3f})")
else:
    print(f"No match (similarity: {sim:.3f})")
```

---

## See Also

- [Face Search Recipe](../recipes/face-search.md) - Building and querying indexes
- [Recognition Module](recognition.md) - Embedding extraction
- [Thresholds Guide](../concepts/thresholds-calibration.md) - Tuning similarity thresholds
@@ -17,6 +17,7 @@ Run UniFace examples directly in your browser with Google Colab, or download and
 | [Face Anonymization](https://github.com/yakhyo/uniface/blob/main/examples/07_face_anonymization.ipynb) | [Open in Colab](https://colab.research.google.com/github/yakhyo/uniface/blob/main/examples/07_face_anonymization.ipynb) | Privacy-preserving blur |
 | [Gaze Estimation](https://github.com/yakhyo/uniface/blob/main/examples/08_gaze_estimation.ipynb) | [Open in Colab](https://colab.research.google.com/github/yakhyo/uniface/blob/main/examples/08_gaze_estimation.ipynb) | Gaze direction estimation |
 | [Face Segmentation](https://github.com/yakhyo/uniface/blob/main/examples/09_face_segmentation.ipynb) | [Open in Colab](https://colab.research.google.com/github/yakhyo/uniface/blob/main/examples/09_face_segmentation.ipynb) | Face segmentation with XSeg |
+| [Face Vector Store](https://github.com/yakhyo/uniface/blob/main/examples/10_face_vector_store.ipynb) | [Open in Colab](https://colab.research.google.com/github/yakhyo/uniface/blob/main/examples/10_face_vector_store.ipynb) | FAISS-backed face database |

 ---
@@ -475,6 +475,7 @@
 from uniface.privacy import BlurFace
 from uniface.spoofing import MiniFASNet
 from uniface.tracking import BYTETracker
 from uniface.analyzer import FaceAnalyzer
+from uniface.indexing import FAISS  # pip install faiss-cpu
 from uniface.draw import draw_detections, draw_tracks
 ```
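A minimal sketch that exercises these imports end to end, using the `draw_detections` keyword signature visible in the notebook hunks further down:

```python
import cv2
from uniface.detection import RetinaFace
from uniface.draw import draw_detections

detector = RetinaFace()

image = cv2.imread("photo.jpg")
faces = detector.detect(image)

bboxes = [f.bbox for f in faces]
scores = [f.confidence for f in faces]
landmarks = [f.landmarks for f in faces]

# corner_bbox replaces the old fancy_bbox flag (see the notebook hunks below)
draw_detections(image=image, bboxes=bboxes, scores=scores, landmarks=landmarks,
                vis_threshold=0.6, corner_bbox=True)
cv2.imwrite("detections.jpg", image)
```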
@@ -1,179 +1,166 @@
 # Face Search

-Build a face search system for finding people in images.
+Find and identify people in images and video streams.

-!!! note "Work in Progress"
-    This page contains example code patterns. Test thoroughly before using in production.
+UniFace supports two search approaches:
+
+| Approach             | Use case                                         | Tool                    |
+| -------------------- | ------------------------------------------------ | ----------------------- |
+| **Reference search** | "Is this specific person in the video?"          | `tools/search.py`       |
+| **Vector search**    | "Who is this?" against a database of known faces | `tools/faiss_search.py` |

 ---

-## Basic Face Database
+## Reference Search (single image)
+
+Compare every detected face against a single reference photo:

 ```python
 import cv2
-import numpy as np
 from uniface.detection import RetinaFace
 from uniface.recognition import ArcFace
+from uniface.face_utils import compute_similarity

+detector = RetinaFace()
+recognizer = ArcFace()
+
+ref_image = cv2.imread("reference.jpg")
+ref_faces = detector.detect(ref_image)
+ref_embedding = recognizer.get_normalized_embedding(ref_image, ref_faces[0].landmarks)
+
+query_image = cv2.imread("group_photo.jpg")
+faces = detector.detect(query_image)
+
+for face in faces:
+    embedding = recognizer.get_normalized_embedding(query_image, face.landmarks)
+    sim = compute_similarity(ref_embedding, embedding)
+
+    label = f"Match ({sim:.2f})" if sim > 0.4 else f"Unknown ({sim:.2f})"
+    print(label)
 ```

+**CLI tool:**
+
+```bash
+python tools/search.py --reference ref.jpg --source video.mp4
+python tools/search.py --reference ref.jpg --source 0  # webcam
+```

 ---

+## Vector Search (FAISS index)
+
+For identifying faces against a database of many known people, use the
+[`FAISS`](../modules/indexing.md) vector store.
+
+!!! info "Install extra"
+    ```bash
+    pip install faiss-cpu
+    ```
+
+### Build an index
+
+Organise face images in person sub-folders:
+
+```
+dataset/
+├── alice/
+│   ├── 001.jpg
+│   └── 002.jpg
+├── bob/
+│   └── 001.jpg
+└── charlie/
+    ├── 001.jpg
+    └── 002.jpg
+```

 ```python
 import cv2
+from pathlib import Path
 from uniface.detection import RetinaFace
 from uniface.recognition import ArcFace
+from uniface.indexing import FAISS

-class FaceDatabase:
-    def __init__(self):
-        self.detector = RetinaFace()
-        self.recognizer = ArcFace()
-        self.embeddings = {}
+detector = RetinaFace()
+recognizer = ArcFace()
+store = FAISS(db_path="./my_index")

-    def add_face(self, person_id, image):
-        """Add a face to the database."""
-        faces = self.detector.detect(image)
-        if not faces:
-            raise ValueError(f"No face found for {person_id}")
+for person_dir in sorted(Path("dataset").iterdir()):
+    if not person_dir.is_dir():
+        continue
+    for img_path in person_dir.glob("*.jpg"):
+        image = cv2.imread(str(img_path))
+        faces = detector.detect(image)
+        if faces:
+            emb = recognizer.get_normalized_embedding(image, faces[0].landmarks)
+            store.add(emb, {"person_id": person_dir.name, "source": str(img_path)})

-        face = max(faces, key=lambda f: f.confidence)
-        embedding = self.recognizer.get_normalized_embedding(image, face.landmarks)
-        self.embeddings[person_id] = embedding
-        return True
-
-    def search(self, image, threshold=0.6):
-        """Search for faces in an image."""
-        faces = self.detector.detect(image)
-        results = []
-
-        for face in faces:
-            embedding = self.recognizer.get_normalized_embedding(image, face.landmarks)
-
-            best_match = None
-            best_similarity = -1
-
-            for person_id, db_embedding in self.embeddings.items():
-                similarity = np.dot(embedding, db_embedding.T)[0][0]
-                if similarity > best_similarity:
-                    best_similarity = similarity
-                    best_match = person_id
-
-            results.append({
-                'bbox': face.bbox,
-                'match': best_match if best_similarity >= threshold else None,
-                'similarity': best_similarity
-            })
-
-        return results
-
-    def save(self, path):
-        """Save database to file."""
-        np.savez(path, embeddings=dict(self.embeddings))
-
-    def load(self, path):
-        """Load database from file."""
-        data = np.load(path, allow_pickle=True)
-        self.embeddings = data['embeddings'].item()
-
-# Usage
-db = FaceDatabase()
-
-# Add faces
-for image_path in Path("known_faces/").glob("*.jpg"):
-    person_id = image_path.stem
-    image = cv2.imread(str(image_path))
-    try:
-        db.add_face(person_id, image)
-        print(f"Added: {person_id}")
-    except ValueError as e:
-        print(f"Skipped: {e}")
-
-# Save database
-db.save("face_database.npz")
-
-# Search
-query_image = cv2.imread("group_photo.jpg")
-results = db.search(query_image)
-
-for r in results:
-    if r['match']:
-        print(f"Found: {r['match']} (similarity: {r['similarity']:.3f})")
+store.save()
+print(f"Index saved: {store}")
 ```

 ---

+**CLI tool:**
-## Visualization
+
+```bash
+python tools/faiss_search.py build --faces-dir dataset/ --db-path ./my_index
+```
+
+### Search against the index

 ```python
 import cv2
 from uniface.detection import RetinaFace
 from uniface.recognition import ArcFace
+from uniface.indexing import FAISS

-def visualize_search_results(image, results):
-    """Draw search results on image."""
-    for r in results:
-        x1, y1, x2, y2 = map(int, r['bbox'])
+detector = RetinaFace()
+recognizer = ArcFace()

-        if r['match']:
-            color = (0, 255, 0)  # Green for match
-            label = f"{r['match']} ({r['similarity']:.2f})"
-        else:
-            color = (0, 0, 255)  # Red for unknown
-            label = f"Unknown ({r['similarity']:.2f})"
+store = FAISS(db_path="./my_index")
+store.load()

-        cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
-        cv2.putText(image, label, (x1, y1 - 10),
-                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
+image = cv2.imread("query.jpg")
+faces = detector.detect(image)

-    return image
+for face in faces:
+    embedding = recognizer.get_normalized_embedding(image, face.landmarks)
+    result, similarity = store.search(embedding, threshold=0.4)

-# Usage
-results = db.search(image)
-annotated = visualize_search_results(image.copy(), results)
-cv2.imwrite("search_result.jpg", annotated)
+    if result:
+        print(f"Matched: {result['person_id']} ({similarity:.2f})")
+    else:
+        print(f"Unknown ({similarity:.2f})")
 ```

 ---

+**CLI tool:**
-## Real-Time Search
+
+```bash
+python tools/faiss_search.py run --db-path ./my_index --source video.mp4
+python tools/faiss_search.py run --db-path ./my_index --source 0  # webcam
+```
+
+### Manage the index

 ```python
-import cv2
+from uniface.indexing import FAISS

-def realtime_search(db):
-    """Real-time face search from webcam."""
-    cap = cv2.VideoCapture(0)
+store = FAISS(db_path="./my_index")
+store.load()

-    while True:
-        ret, frame = cap.read()
-        if not ret:
-            break
+print(f"Total vectors: {len(store)}")

-        results = db.search(frame, threshold=0.5)
+removed = store.remove("person_id", "bob")
+print(f"Removed {removed} entries")

-        for r in results:
-            x1, y1, x2, y2 = map(int, r['bbox'])
-
-            if r['match']:
-                color = (0, 255, 0)
-                label = r['match']
-            else:
-                color = (0, 0, 255)
-                label = "Unknown"
-
-            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
-            cv2.putText(frame, label, (x1, y1 - 10),
-                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
-
-        cv2.imshow("Face Search", frame)
-
-        if cv2.waitKey(1) & 0xFF == ord('q'):
-            break
-
-    cap.release()
-    cv2.destroyAllWindows()
-
-# Usage
-db = FaceDatabase()
-db.load("face_database.npz")
-realtime_search(db)
+store.save()
 ```

 ---

 ## See Also

+- [Indexing Module](../modules/indexing.md) - Full `FAISS` API reference
 - [Recognition Module](../modules/recognition.md) - Face recognition details
 - [Batch Processing](batch-processing.md) - Process multiple files
 - [Video & Webcam](video-webcam.md) - Real-time processing
 - [Concepts: Thresholds](../concepts/thresholds-calibration.md) - Tuning similarity thresholds
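Note the operating-point shift in the rewrite: the old `FaceDatabase` defaulted to a 0.6 threshold while the new page searches at 0.4. A rough, illustrative way to calibrate a threshold from your own genuine/impostor similarity samples (hypothetical helper, not part of UniFace):

```python
import numpy as np

def pick_threshold(genuine: np.ndarray, impostor: np.ndarray, target_far: float = 1e-3) -> float:
    """Smallest threshold whose impostor accept rate stays at or below target_far."""
    candidates = np.sort(np.concatenate([genuine, impostor]))
    for t in candidates:
        far = float((impostor >= t).mean())  # fraction of impostors wrongly accepted
        if far <= target_far:
            return float(t)
    return float(candidates[-1])

# genuine = similarities of same-person pairs, impostor = different-person pairs
# threshold = pick_threshold(genuine, impostor, target_far=1e-3)
```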
@@ -51,7 +51,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"2.0.0\n"
+"3.0.0\n"
 ]
 }
 ],

@@ -62,7 +62,7 @@
 "\n",
 "import uniface\n",
 "from uniface.detection import RetinaFace\n",
-"from uniface.visualization import draw_detections\n",
+"from uniface.draw import draw_detections\n",
 "\n",
 "print(uniface.__version__)"
 ]

@@ -162,7 +162,7 @@
 "landmarks = [f.landmarks for f in faces]\n",
 "\n",
 "# Draw detections\n",
-"draw_detections(image=image, bboxes=bboxes, scores=scores, landmarks=landmarks, vis_threshold=0.6, fancy_bbox=True)\n",
+"draw_detections(image=image, bboxes=bboxes, scores=scores, landmarks=landmarks, vis_threshold=0.6, corner_bbox=True)\n",
 "\n",
 "# Display result\n",
 "output_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",

@@ -214,7 +214,7 @@
 "scores = [f.confidence for f in faces]\n",
 "landmarks = [f.landmarks for f in faces]\n",
 "\n",
-"draw_detections(image=image, bboxes=bboxes, scores=scores, landmarks=landmarks, vis_threshold=0.6, fancy_bbox=True)\n",
+"draw_detections(image=image, bboxes=bboxes, scores=scores, landmarks=landmarks, vis_threshold=0.6, corner_bbox=True)\n",
 "\n",
 "output_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
 "display.display(Image.fromarray(output_image))"

@@ -261,7 +261,7 @@
 "scores = [f.confidence for f in faces]\n",
 "landmarks = [f.landmarks for f in faces]\n",
 "\n",
-"draw_detections(image=image, bboxes=bboxes, scores=scores, landmarks=landmarks, vis_threshold=0.6, fancy_bbox=True)\n",
+"draw_detections(image=image, bboxes=bboxes, scores=scores, landmarks=landmarks, vis_threshold=0.6, corner_bbox=True)\n",
 "\n",
 "output_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
 "display.display(Image.fromarray(output_image))"

@@ -55,7 +55,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"2.0.0\n"
+"3.0.0\n"
 ]
 }
 ],

@@ -67,7 +67,7 @@
 "import uniface\n",
 "from uniface.detection import RetinaFace\n",
 "from uniface.face_utils import face_alignment\n",
-"from uniface.visualization import draw_detections\n",
+"from uniface.draw import draw_detections\n",
 "\n",
 "print(uniface.__version__)"
 ]

@@ -142,7 +142,7 @@
 "    bboxes = [f.bbox for f in faces]\n",
 "    scores = [f.confidence for f in faces]\n",
 "    landmarks = [f.landmarks for f in faces]\n",
-"    draw_detections(image=bbox_image, bboxes=bboxes, scores=scores, landmarks=landmarks, vis_threshold=0.6, fancy_bbox=True)\n",
+"    draw_detections(image=bbox_image, bboxes=bboxes, scores=scores, landmarks=landmarks, vis_threshold=0.6, corner_bbox=True)\n",
 "\n",
 "    # Align first detected face (returns aligned image and inverse transform matrix)\n",
 "    first_landmarks = faces[0].landmarks\n",

@@ -44,7 +44,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"2.0.0\n"
+"3.0.0\n"
 ]
 }
 ],

@@ -11,7 +11,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 1,
 "metadata": {},
 "outputs": [
 {

@@ -49,7 +49,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"2.0.0\n"
+"3.0.0\n"
 ]
 }
 ],

@@ -69,16 +69,7 @@
 "cell_type": "code",
 "execution_count": 3,
 "metadata": {},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"✓ Model loaded (CoreML (Apple Silicon))\n",
-"✓ Model loaded (CoreML (Apple Silicon))\n"
-]
-}
-],
+"outputs": [],
 "source": [
 "analyzer = FaceAnalyzer(\n",
 "    detector=RetinaFace(confidence_threshold=0.5),\n",

@@ -51,7 +51,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"2.0.0\n"
+"3.0.0\n"
 ]
 }
 ],

@@ -64,7 +64,7 @@
 "from uniface.detection import RetinaFace\n",
 "from uniface.recognition import ArcFace\n",
 "from uniface.attribute import AgeGender\n",
-"from uniface.visualization import draw_detections\n",
+"from uniface.draw import draw_detections\n",
 "\n",
 "print(uniface.__version__)"
 ]

@@ -148,7 +148,7 @@
 "    bboxes = [f.bbox for f in faces]\n",
 "    scores = [f.confidence for f in faces]\n",
 "    landmarks = [f.landmarks for f in faces]\n",
-"    draw_detections(image=vis_image, bboxes=bboxes, scores=scores, landmarks=landmarks, vis_threshold=0.5, fancy_bbox=True)\n",
+"    draw_detections(image=vis_image, bboxes=bboxes, scores=scores, landmarks=landmarks, vis_threshold=0.5, corner_bbox=True)\n",
 "\n",
 "    results.append((image_path, cv2.cvtColor(vis_image, cv2.COLOR_BGR2RGB), faces))"
 ]

@@ -15,7 +15,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 1,
 "metadata": {},
 "outputs": [
 {

@@ -53,7 +53,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"UniFace version: 2.0.0\n"
+"UniFace version: 3.0.0\n"
 ]
 }
 ],

@@ -66,7 +66,7 @@
 "import uniface\n",
 "from uniface.parsing import BiSeNet\n",
 "from uniface.constants import ParsingWeights\n",
-"from uniface.visualization import vis_parsing_maps\n",
+"from uniface.draw import vis_parsing_maps\n",
 "\n",
 "print(f\"UniFace version: {uniface.__version__}\")"
 ]

@@ -82,15 +82,7 @@
 "cell_type": "code",
 "execution_count": 3,
 "metadata": {},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"✓ Model loaded (CoreML (Apple Silicon))\n"
-]
-}
-],
+"outputs": [],
 "source": [
 "# Initialize face parser (uses ResNet18 by default)\n",
 "parser = BiSeNet(model_name=ParsingWeights.RESNET34)  # use resnet34 for better accuracy"

@@ -51,7 +51,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"UniFace version: 2.0.0\n"
+"UniFace version: 3.0.0\n"
 ]
 }
 ],

@@ -65,7 +65,7 @@
 "import uniface\n",
 "from uniface.detection import RetinaFace\n",
 "from uniface.gaze import MobileGaze\n",
-"from uniface.visualization import draw_gaze\n",
+"from uniface.draw import draw_gaze\n",
 "\n",
 "print(f\"UniFace version: {uniface.__version__}\")"
 ]

@@ -110,19 +110,19 @@
 "text": [
 "Processing: image0.jpg\n",
 "  Detected 1 face(s)\n",
-"  Face 1: pitch=-0.0°, yaw=7.1°\n",
+"  Face 1: pitch=7.1°, yaw=-0.0°\n",
 "Processing: image1.jpg\n",
 "  Detected 1 face(s)\n",
-"  Face 1: pitch=-3.3°, yaw=-5.6°\n",
+"  Face 1: pitch=-5.6°, yaw=-3.3°\n",
 "Processing: image2.jpg\n",
 "  Detected 1 face(s)\n",
-"  Face 1: pitch=-3.9°, yaw=-0.3°\n",
+"  Face 1: pitch=-0.3°, yaw=-3.9°\n",
 "Processing: image3.jpg\n",
 "  Detected 1 face(s)\n",
-"  Face 1: pitch=-22.1°, yaw=1.0°\n",
+"  Face 1: pitch=1.0°, yaw=-22.1°\n",
 "Processing: image4.jpg\n",
 "  Detected 1 face(s)\n",
-"  Face 1: pitch=2.1°, yaw=5.0°\n",
+"  Face 1: pitch=5.0°, yaw=2.1°\n",
 "\n",
 "Processed 5 images\n"
 ]

@@ -53,7 +53,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"UniFace version: 2.2.1\n"
+"UniFace version: 3.0.0\n"
 ]
 }
 ],

@@ -364,7 +364,7 @@
 ],
 "source": [
 "from uniface.parsing import BiSeNet\n",
-"from uniface.visualization import vis_parsing_maps\n",
+"from uniface.draw import vis_parsing_maps\n",
 "\n",
 "# Load image and detect\n",
 "image = cv2.imread(\"../assets/einstien.png\")\n",

@@ -481,13 +481,21 @@
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "Python 3",
+"display_name": "base",
 "language": "python",
 "name": "python3"
 },
 "language_info": {
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"file_extension": ".py",
+"mimetype": "text/x-python",
 "name": "python",
-"version": "3.10.0"
+"nbconvert_exporter": "python",
+"pygments_lexer": "ipython3",
+"version": "3.13.5"
 }
 },
 "nbformat": 4,
examples/10_face_vector_store.ipynb (new file, 366 lines)
@@ -135,6 +135,7 @@ nav:
   - Quickstart: quickstart.md
   - Notebooks: notebooks.md
   - Model Zoo: models.md
+  - Datasets: datasets.md
   - Tutorials:
       - Image Pipeline: recipes/image-pipeline.md
       - Video & Webcam: recipes/video-webcam.md

@@ -152,6 +153,7 @@ nav:
       - Gaze: modules/gaze.md
       - Anti-Spoofing: modules/spoofing.md
       - Privacy: modules/privacy.md
+      - Indexing: modules/indexing.md
   - Guides:
       - Overview: concepts/overview.md
       - Inputs & Outputs: concepts/inputs-outputs.md
@@ -1,6 +1,6 @@
 [project]
 name = "uniface"
-version = "3.0.0"
+version = "3.1.0"
 description = "UniFace: A Comprehensive Library for Face Detection, Recognition, Tracking, Landmark Analysis, Face Parsing, Gaze Estimation, Age, and Gender Detection"
 readme = "README.md"
 license = "MIT"

@@ -29,7 +29,7 @@ keywords = [
 ]

 classifiers = [
-    "Development Status :: 4 - Beta",
+    "Development Status :: 5 - Production/Stable",
     "Intended Audience :: Developers",
     "Intended Audience :: Science/Research",
     "Operating System :: OS Independent",
@@ -15,6 +15,7 @@ from __future__ import annotations
 import argparse
 import os
 from pathlib import Path
+import time

 from _common import get_source_type
 import cv2

@@ -83,6 +84,7 @@ def process_video(
         if not ret:
             break

+        t0 = time.perf_counter()
         frame_count += 1
         faces = detector.detect(frame)
         total_faces += len(faces)

@@ -100,7 +102,9 @@ def process_video(
             corner_bbox=True,
         )

-        cv2.putText(frame, f'Faces: {len(faces)}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+        inference_fps = 1.0 / max(time.perf_counter() - t0, 1e-9)
+        cv2.putText(frame, f'FPS: {inference_fps:.1f}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+        cv2.putText(frame, f'Faces: {len(faces)}', (10, 65), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
         out.write(frame)

         if show_preview:

@@ -128,6 +132,7 @@ def run_camera(detector, camera_id: int = 0, threshold: float = 0.6):

     print("Press 'q' to quit")

+    prev_time = time.perf_counter()
     while True:
         ret, frame = cap.read()
         frame = cv2.flip(frame, 1)

@@ -149,7 +154,11 @@ def run_camera(detector, camera_id: int = 0, threshold: float = 0.6):
             corner_bbox=True,
         )

-        cv2.putText(frame, f'Faces: {len(faces)}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+        curr_time = time.perf_counter()
+        fps = 1.0 / max(curr_time - prev_time, 1e-9)
+        prev_time = curr_time
+        cv2.putText(frame, f'FPS: {fps:.1f}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+        cv2.putText(frame, f'Faces: {len(faces)}', (10, 65), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
         cv2.imshow('Face Detection', frame)

         if cv2.waitKey(1) & 0xFF == ord('q'):
208
tools/faiss_search.py
Normal file
@@ -0,0 +1,208 @@
# Copyright 2025-2026 Yakhyokhuja Valikhujaev
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo

"""FAISS index build and multi-identity face search.

Build a vector index from a directory of person sub-folders, then search
against it in a video or webcam stream.

Usage:
    python tools/faiss_search.py build --faces-dir dataset/ --db-path ./vector_index
    python tools/faiss_search.py run --db-path ./vector_index --source video.mp4
    python tools/faiss_search.py run --db-path ./vector_index --source 0  # webcam
"""

from __future__ import annotations

import argparse
import os
from pathlib import Path

from _common import IMAGE_EXTENSIONS, get_source_type
import cv2

from uniface import create_detector, create_recognizer
from uniface.draw import draw_corner_bbox, draw_text_label
from uniface.indexing import FAISS


def _draw_face(image, bbox, text: str, color: tuple[int, int, int]) -> None:
    x1, y1, x2, y2 = map(int, bbox[:4])
    thickness = max(round(sum(image.shape[:2]) / 2 * 0.003), 2)
    font_scale = max(0.4, min(0.7, (y2 - y1) / 200))
    draw_corner_bbox(image, (x1, y1, x2, y2), color=color, thickness=thickness)
    draw_text_label(image, text, x1, y1, bg_color=color, font_scale=font_scale)


def process_frame(frame, detector, recognizer, store: FAISS, threshold: float = 0.4):
    faces = detector.detect(frame)
    if not faces:
        return frame

    for face in faces:
        embedding = recognizer.get_normalized_embedding(frame, face.landmarks)
        result, sim = store.search(embedding, threshold=threshold)

        text = f'{result["person_id"]} ({sim:.2f})' if result else f'Unknown ({sim:.2f})'
        color = (0, 255, 0) if result else (0, 0, 255)
        _draw_face(frame, face.bbox, text, color)

    return frame


def process_video(detector, recognizer, store: FAISS, video_path: str, save_dir: str, threshold: float = 0.4):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Cannot open video file '{video_path}'")
        return

    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    os.makedirs(save_dir, exist_ok=True)
    output_path = os.path.join(save_dir, f'{Path(video_path).stem}_faiss_search.mp4')
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    print(f'Processing video: {video_path} ({total_frames} frames)')
    frame_count = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_count += 1
        frame = process_frame(frame, detector, recognizer, store, threshold)
        out.write(frame)

        if frame_count % 100 == 0:
            print(f'  Processed {frame_count}/{total_frames} frames...')

    cap.release()
    out.release()
    print(f'Done! Output saved: {output_path}')


def run_camera(detector, recognizer, store: FAISS, camera_id: int = 0, threshold: float = 0.4):
    cap = cv2.VideoCapture(camera_id)
    if not cap.isOpened():
        print(f'Cannot open camera {camera_id}')
        return

    print("Press 'q' to quit")

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Flip only after the read has been validated (flipping a failed read would crash).
        frame = cv2.flip(frame, 1)

        frame = process_frame(frame, detector, recognizer, store, threshold)

        cv2.imshow('Vector Search', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()


def build(args: argparse.Namespace) -> None:
    faces_dir = Path(args.faces_dir)
    if not faces_dir.is_dir():
        print(f"Error: '{faces_dir}' is not a directory")
        return

    detector = create_detector()
    recognizer = create_recognizer()
    store = FAISS(db_path=args.db_path)

    persons = sorted(p.name for p in faces_dir.iterdir() if p.is_dir())
    if not persons:
        print(f"Error: No sub-folders found in '{faces_dir}'")
        return

    print(f'Found {len(persons)} persons: {", ".join(persons)}')

    total_added = 0
    for person_id in persons:
        person_dir = faces_dir / person_id
        images = [f for f in person_dir.iterdir() if f.suffix.lower() in IMAGE_EXTENSIONS]

        added = 0
        for img_path in images:
            image = cv2.imread(str(img_path))
            if image is None:
                print(f'  Warning: Failed to read {img_path}, skipping')
                continue

            faces = detector.detect(image)
            if not faces:
                print(f'  Warning: No face detected in {img_path}, skipping')
                continue

            embedding = recognizer.get_normalized_embedding(image, faces[0].landmarks)
            store.add(embedding, {'person_id': person_id, 'source': str(img_path)})
            added += 1

        total_added += added
        if added:
            print(f'  {person_id}: {added} embeddings added')
        else:
            print(f'  {person_id}: no valid faces found')

    store.save()
    print(f'\nIndex saved to {args.db_path} ({total_added} vectors, {len(persons)} persons)')


def run(args: argparse.Namespace) -> None:
    detector = create_detector()
    recognizer = create_recognizer()

    store = FAISS(db_path=args.db_path)
    if not store.load():
        print(f"Error: No index found at '{args.db_path}'")
        return
    print(f'Loaded FAISS index: {store}')

    source_type = get_source_type(args.source)

    if source_type == 'camera':
        run_camera(detector, recognizer, store, int(args.source), args.threshold)
    elif source_type == 'video':
        if not os.path.exists(args.source):
            print(f'Error: Video not found: {args.source}')
            return
        process_video(detector, recognizer, store, args.source, args.save_dir, args.threshold)
    else:
        print(f"Error: Source must be a video file or camera ID, not '{args.source}'")


def main():
    parser = argparse.ArgumentParser(description='FAISS vector search')
    sub = parser.add_subparsers(dest='command', required=True)

    build_p = sub.add_parser('build', help='Build a FAISS index from person sub-folders')
    build_p.add_argument('--faces-dir', type=str, required=True, help='Directory with person sub-folders')
    build_p.add_argument('--db-path', type=str, default='./vector_index', help='Where to save the index')

    run_p = sub.add_parser('run', help='Search faces against a FAISS index')
    run_p.add_argument('--db-path', type=str, required=True, help='Path to saved FAISS index')
    run_p.add_argument('--source', type=str, required=True, help='Video path or camera ID')
    run_p.add_argument('--threshold', type=float, default=0.4, help='Similarity threshold')
    run_p.add_argument('--save-dir', type=str, default='outputs', help='Output directory')

    args = parser.parse_args()

    if args.command == 'build':
        build(args)
    elif args.command == 'run':
        run(args)


if __name__ == '__main__':
    main()
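Stripped of the video plumbing, tools/faiss_search.py is an add/search loop over L2-normalized embeddings. Below is a minimal still-image sketch of the same flow, using only the `FAISS` calls that appear in the tool above (`add`, `search`, `save`); the image paths and identity names are illustrative, not part of the repository:

```python
import cv2

from uniface import create_detector, create_recognizer
from uniface.indexing import FAISS

detector = create_detector()
recognizer = create_recognizer()
store = FAISS(db_path='./vector_index')

# Enroll one reference image per identity (paths are illustrative).
for person_id, path in [('alice', 'alice.jpg'), ('bob', 'bob.jpg')]:
    image = cv2.imread(path)
    faces = detector.detect(image)
    embedding = recognizer.get_normalized_embedding(image, faces[0].landmarks)
    store.add(embedding, {'person_id': person_id, 'source': path})
store.save()

# Query: returns (metadata, similarity); metadata is falsy below the threshold.
query = cv2.imread('unknown.jpg')
faces = detector.detect(query)
embedding = recognizer.get_normalized_embedding(query, faces[0].landmarks)
result, sim = store.search(embedding, threshold=0.4)
print(result['person_id'] if result else 'Unknown', f'({sim:.2f})')
```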
tools/search.py
@@ -2,11 +2,14 @@
 # Author: Yakhyokhuja Valikhujaev
 # GitHub: https://github.com/yakhyo

-"""Real-time face search: match faces against a reference image.
+"""Single-reference face search on video or webcam.
+
+Given a reference face image, detects faces in the source and shows
+whether each face matches the reference.

 Usage:
-    python tools/search.py --reference person.jpg --source 0  # webcam
-    python tools/search.py --reference person.jpg --source video.mp4
+    python tools/search.py --reference ref.jpg --source video.mp4
+    python tools/search.py --reference ref.jpg --source 0  # webcam
 """

 from __future__ import annotations
@@ -19,23 +22,12 @@ from _common import get_source_type
 import cv2
 import numpy as np

-from uniface.detection import SCRFD, RetinaFace
+from uniface import create_detector, create_recognizer
 from uniface.draw import draw_corner_bbox, draw_text_label
 from uniface.face_utils import compute_similarity
-from uniface.recognition import ArcFace, MobileFace, SphereFace
-
-
-def get_recognizer(name: str):
-    """Get recognizer by name."""
-    if name == 'arcface':
-        return ArcFace()
-    elif name == 'mobileface':
-        return MobileFace()
-    else:
-        return SphereFace()


 def extract_reference_embedding(detector, recognizer, image_path: str) -> np.ndarray:
     """Extract embedding from reference image."""
     image = cv2.imread(image_path)
     if image is None:
         raise RuntimeError(f'Failed to load image: {image_path}')
@@ -44,33 +36,34 @@ def extract_reference_embedding(detector, recognizer, image_path: str) -> np.nda
     if not faces:
         raise RuntimeError('No faces found in reference image.')

-    landmarks = faces[0].landmarks
-    return recognizer.get_normalized_embedding(image, landmarks)
+    return recognizer.get_normalized_embedding(image, faces[0].landmarks)


+def _draw_face(image, bbox, text: str, color: tuple[int, int, int]) -> None:
+    x1, y1, x2, y2 = map(int, bbox[:4])
+    thickness = max(round(sum(image.shape[:2]) / 2 * 0.003), 2)
+    font_scale = max(0.4, min(0.7, (y2 - y1) / 200))
+    draw_corner_bbox(image, (x1, y1, x2, y2), color=color, thickness=thickness)
+    draw_text_label(image, text, x1, y1, bg_color=color, font_scale=font_scale)
+
+
 def process_frame(frame, detector, recognizer, ref_embedding: np.ndarray, threshold: float = 0.4):
     """Process a single frame and return annotated frame."""
     faces = detector.detect(frame)

     for face in faces:
-        bbox = face.bbox
-        landmarks = face.landmarks
-        x1, y1, x2, y2 = map(int, bbox)
-
-        embedding = recognizer.get_normalized_embedding(frame, landmarks)
+        embedding = recognizer.get_normalized_embedding(frame, face.landmarks)
         sim = compute_similarity(ref_embedding, embedding)

-        label = f'Match ({sim:.2f})' if sim > threshold else f'Unknown ({sim:.2f})'
+        text = f'Match ({sim:.2f})' if sim > threshold else f'Unknown ({sim:.2f})'
         color = (0, 255, 0) if sim > threshold else (0, 0, 255)

-        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
-        cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
+        _draw_face(frame, face.bbox, text, color)

     return frame


-def process_video(detector, recognizer, ref_embedding: np.ndarray, video_path: str, save_dir: str, threshold: float):
-    """Process a video file."""
+def process_video(
+    detector, recognizer, video_path: str, save_dir: str, ref_embedding: np.ndarray, threshold: float = 0.4
+):
     cap = cv2.VideoCapture(video_path)
     if not cap.isOpened():
         print(f"Error: Cannot open video file '{video_path}'")
@@ -107,7 +100,6 @@ def process_video(detector, recognizer, ref_embedding: np.ndarray, video_path: s


 def run_camera(detector, recognizer, ref_embedding: np.ndarray, camera_id: int = 0, threshold: float = 0.4):
-    """Run real-time face search on webcam."""
     cap = cv2.VideoCapture(camera_id)
     if not cap.isOpened():
         print(f'Cannot open camera {camera_id}')
@@ -123,7 +115,7 @@ def run_camera(detector, recognizer, ref_embedding: np.ndarray, camera_id: int =

         frame = process_frame(frame, detector, recognizer, ref_embedding, threshold)

-        cv2.imshow('Face Recognition', frame)
+        cv2.imshow('Face Search', frame)
         if cv2.waitKey(1) & 0xFF == ord('q'):
             break

@@ -132,17 +124,10 @@ def run_camera(detector, recognizer, ref_embedding: np.ndarray, camera_id: int =


 def main():
-    parser = argparse.ArgumentParser(description='Face search using a reference image')
+    parser = argparse.ArgumentParser(description='Single-reference face search')
     parser.add_argument('--reference', type=str, required=True, help='Reference face image')
-    parser.add_argument('--source', type=str, required=True, help='Video path or camera ID (0, 1, ...)')
-    parser.add_argument('--threshold', type=float, default=0.4, help='Match threshold')
-    parser.add_argument('--detector', type=str, default='scrfd', choices=['retinaface', 'scrfd'])
-    parser.add_argument(
-        '--recognizer',
-        type=str,
-        default='arcface',
-        choices=['arcface', 'mobileface', 'sphereface'],
-    )
+    parser.add_argument('--source', type=str, required=True, help='Video path or camera ID')
+    parser.add_argument('--threshold', type=float, default=0.4, help='Similarity threshold')
     parser.add_argument('--save-dir', type=str, default='outputs', help='Output directory')
     args = parser.parse_args()

@@ -150,8 +135,8 @@ def main():
         print(f'Error: Reference image not found: {args.reference}')
         return

-    detector = RetinaFace() if args.detector == 'retinaface' else SCRFD()
-    recognizer = get_recognizer(args.recognizer)
+    detector = create_detector()
+    recognizer = create_recognizer()

     print(f'Loading reference: {args.reference}')
     ref_embedding = extract_reference_embedding(detector, recognizer, args.reference)
@@ -164,10 +149,9 @@ def main():
         if not os.path.exists(args.source):
             print(f'Error: Video not found: {args.source}')
             return
-        process_video(detector, recognizer, ref_embedding, args.source, args.save_dir, args.threshold)
+        process_video(detector, recognizer, args.source, args.save_dir, ref_embedding, args.threshold)
     else:
         print(f"Error: Source must be a video file or camera ID, not '{args.source}'")
-        print('Supported formats: videos (.mp4, .avi, ...) or camera ID (0, 1, ...)')


 if __name__ == '__main__':
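The match decision in `process_frame` is one cosine-similarity comparison between the reference embedding and each detected face. A minimal two-image sketch of the same check, using the APIs from the hunks above (`ref.jpg` and `probe.jpg` are illustrative paths):

```python
import cv2

from uniface import create_detector, create_recognizer
from uniface.face_utils import compute_similarity

detector = create_detector()
recognizer = create_recognizer()


def embed(path: str):
    """Embedding of the first detected face; raises if the image has none."""
    image = cv2.imread(path)
    faces = detector.detect(image)
    if not faces:
        raise RuntimeError(f'No face found in {path}')
    return recognizer.get_normalized_embedding(image, faces[0].landmarks)


sim = compute_similarity(embed('ref.jpg'), embed('probe.jpg'))
print('Match' if sim > 0.4 else 'Unknown', f'({sim:.2f})')
```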
180
uniface-cpp/.clang-format
Normal file
@@ -0,0 +1,180 @@
---
# Modern C++ style based on Google with enhancements
Language: Cpp
Standard: c++17

BasedOnStyle: Google
ColumnLimit: 100
IndentWidth: 4
TabWidth: 4
UseTab: Never

# Access modifiers
AccessModifierOffset: -4
IndentAccessModifiers: false

# Alignment
AlignAfterOpenBracket: BlockIndent
AlignArrayOfStructures: Right
AlignConsecutiveAssignments:
  Enabled: false
AlignConsecutiveBitFields:
  Enabled: true
AlignConsecutiveDeclarations:
  Enabled: false
AlignConsecutiveMacros:
  Enabled: true
AlignEscapedNewlines: Left
AlignOperands: AlignAfterOperator
AlignTrailingComments:
  Kind: Always
  OverEmptyLines: 1

# Arguments and parameters
AllowAllArgumentsOnNextLine: true
AllowAllParametersOfDeclarationOnNextLine: true
BinPackArguments: false
BinPackParameters: false

# Short forms
AllowShortBlocksOnASingleLine: Empty
AllowShortCaseLabelsOnASingleLine: false
AllowShortEnumsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Inline
AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Inline
AllowShortLoopsOnASingleLine: false

# Break behavior
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: Yes
BreakAfterAttributes: Leave
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeConceptDeclarations: Always
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: BeforeComma
BreakInheritanceList: BeforeComma
BreakStringLiterals: true

# Braces
InsertBraces: false
RemoveBracesLLVM: false

# Constructors
PackConstructorInitializers: CurrentLine
ConstructorInitializerIndentWidth: 4

# Empty lines
EmptyLineAfterAccessModifier: Never
EmptyLineBeforeAccessModifier: LogicalBlock
KeepEmptyLinesAtTheStartOfBlocks: false
MaxEmptyLinesToKeep: 1
SeparateDefinitionBlocks: Always

# Includes
IncludeBlocks: Regroup
IncludeCategories:
  # Main header (same name as source file)
  - Regex: '^"([a-zA-Z0-9_]+)\.(h|hpp)"$'
    Priority: 1
    SortPriority: 1
    CaseSensitive: true
  # Project headers
  - Regex: '^".*"$'
    Priority: 2
    SortPriority: 2
  # C system headers
  - Regex: '^<(assert|complex|ctype|errno|fenv|float|inttypes|iso646|limits|locale|math|setjmp|signal|stdalign|stdarg|stdatomic|stdbool|stddef|stdint|stdio|stdlib|stdnoreturn|string|tgmath|threads|time|uchar|wchar|wctype)\.h>$'
    Priority: 3
    SortPriority: 3
  # C++ standard library
  - Regex: '^<[a-z_]+>$'
    Priority: 4
    SortPriority: 4
  # External libraries
  - Regex: '^<.*>$'
    Priority: 5
    SortPriority: 5
SortIncludes: CaseSensitive

# Indentation
IndentCaseBlocks: false
IndentCaseLabels: true
IndentExternBlock: NoIndent
IndentGotoLabels: false
IndentPPDirectives: AfterHash
IndentRequiresClause: true
IndentWrappedFunctionNames: false

# Lambdas
LambdaBodyIndentation: Signature

# Namespaces
CompactNamespaces: false
FixNamespaceComments: true
NamespaceIndentation: None
ShortNamespaceLines: 0

# Penalties (guide formatting decisions)
PenaltyBreakAssignment: 25
PenaltyBreakBeforeFirstCallParameter: 19
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakOpenParenthesis: 0
PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyIndentedWhitespace: 0
PenaltyReturnTypeOnItsOwnLine: 200

# Pointers and references
DerivePointerAlignment: false
PointerAlignment: Left
ReferenceAlignment: Pointer
QualifierAlignment: Leave

# Requires clause (C++20 concepts)
RequiresClausePosition: OwnLine
RequiresExpressionIndentation: OuterScope

# Spacing
BitFieldColonSpacing: Both
SpaceAfterCStyleCast: false
SpaceAfterLogicalNot: false
SpaceAfterTemplateKeyword: true
SpaceAroundPointerQualifiers: Default
SpaceBeforeAssignmentOperators: true
SpaceBeforeCaseColon: false
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceBeforeSquareBrackets: false
SpaceInEmptyBlock: false
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: Never
SpacesInCStyleCastParentheses: false
SpacesInConditionalStatement: false
SpacesInContainerLiterals: false
SpacesInLineCommentPrefix:
  Minimum: 1
  Maximum: -1
SpacesInParentheses: false
SpacesInSquareBrackets: false

# Other
Cpp11BracedListStyle: true
InsertNewlineAtEOF: true
InsertTrailingCommas: None
IntegerLiteralSeparator:
  Binary: 4
  Decimal: 3
  Hex: 4
ReflowComments: true
RemoveSemicolon: false
SortUsingDeclarations: LexicographicNumeric
...
51
uniface-cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,51 @@
cmake_minimum_required(VERSION 3.14)

project(uniface
    VERSION 1.0.0
    DESCRIPTION "Uniface C++ face analysis library"
    LANGUAGES CXX
)

# Options
option(UNIFACE_BUILD_EXAMPLES "Build example programs" ON)

# C++ standard
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Compiler warnings
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
    add_compile_options(-Wall -Wextra -Wpedantic)
elseif(MSVC)
    add_compile_options(/W4)
endif()

# Find dependencies
find_package(OpenCV REQUIRED COMPONENTS core imgproc dnn calib3d)

# Library
add_library(uniface
    src/utils.cpp
    src/detector.cpp
    src/recognizer.cpp
    src/landmarker.cpp
    src/analyzer.cpp
)

target_include_directories(uniface
    PUBLIC
        $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
        $<INSTALL_INTERFACE:include>
)

target_link_libraries(uniface
    PUBLIC
        ${OpenCV_LIBS}
)

# Examples
if(UNIFACE_BUILD_EXAMPLES)
    add_subdirectory(examples)
endif()
69
uniface-cpp/README.md
Normal file
@@ -0,0 +1,69 @@
# Uniface C++

C++ implementation of the Uniface face analysis library.

## Features

- **Face Detection** - RetinaFace detector with 5-point landmarks

## Requirements

- C++17 compiler
- CMake 3.14+
- OpenCV 4.x

## Build

```bash
mkdir build && cd build
cmake ..
make -j$(nproc)
```

## Usage

### Image Detection

```bash
./examples/detect <model_path> <image_path>
```

### Webcam Demo

```bash
./examples/webcam <model_path> [camera_id]
```

### Code Example

```cpp
#include <uniface/uniface.hpp>
#include <opencv2/highgui.hpp>

int main() {
    uniface::RetinaFace detector("retinaface.onnx");

    cv::Mat image = cv::imread("photo.jpg");
    auto faces = detector.detect(image);

    for (const auto& face : faces) {
        cv::rectangle(image, face.bbox, cv::Scalar(0, 255, 0), 2);
    }

    cv::imwrite("result.jpg", image);
    return 0;
}
```

## Models

Download models from the main uniface repository or use:

```bash
# RetinaFace MobileNet V2
wget https://github.com/your-repo/uniface/releases/download/v1.0/retinaface_mv2.onnx -P models/
```

## License

Same license as the main uniface project.
23
uniface-cpp/examples/CMakeLists.txt
Normal file
@@ -0,0 +1,23 @@
# Examples

find_package(OpenCV REQUIRED COMPONENTS highgui imgcodecs videoio)

# Image detection example
add_executable(detect detect.cpp)
target_link_libraries(detect PRIVATE uniface ${OpenCV_LIBS})

# Face recognition example
add_executable(recognize recognize.cpp)
target_link_libraries(recognize PRIVATE uniface ${OpenCV_LIBS})

# Facial landmarks example
add_executable(landmarks landmarks.cpp)
target_link_libraries(landmarks PRIVATE uniface ${OpenCV_LIBS})

# Face analyzer example
add_executable(analyzer analyzer.cpp)
target_link_libraries(analyzer PRIVATE uniface ${OpenCV_LIBS})

# Webcam example
add_executable(webcam webcam.cpp)
target_link_libraries(webcam PRIVATE uniface ${OpenCV_LIBS})
113
uniface-cpp/examples/analyzer.cpp
Normal file
@@ -0,0 +1,113 @@
#include <iomanip>
#include <iostream>

#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>
#include <uniface/uniface.hpp>

int main(int argc, char** argv) {
    if (argc < 3) {
        std::cout << "Usage: " << argv[0]
                  << " <detector_model> <image_path> [recognizer_model] [landmark_model]"
                  << std::endl;
        std::cout << "\nAnalyzes faces in an image using available models." << std::endl;
        std::cout << " - detector_model: Required. Path to face detector ONNX model." << std::endl;
        std::cout << " - recognizer_model: Optional. Path to face recognizer ONNX model."
                  << std::endl;
        std::cout << " - landmark_model: Optional. Path to 106-point landmark ONNX model."
                  << std::endl;
        return 1;
    }

    const std::string detector_path = argv[1];
    const std::string image_path = argv[2];
    const std::string recognizer_path = (argc > 3) ? argv[3] : "";
    const std::string landmark_path = (argc > 4) ? argv[4] : "";

    try {
        // Create analyzer and load components
        uniface::FaceAnalyzer analyzer;

        std::cout << "Loading detector: " << detector_path << std::endl;
        analyzer.loadDetector(detector_path);

        if (!recognizer_path.empty()) {
            std::cout << "Loading recognizer: " << recognizer_path << std::endl;
            analyzer.loadRecognizer(recognizer_path);
        }

        if (!landmark_path.empty()) {
            std::cout << "Loading landmarker: " << landmark_path << std::endl;
            analyzer.loadLandmarker(landmark_path);
        }

        // Load image
        cv::Mat image = cv::imread(image_path);
        if (image.empty()) {
            std::cerr << "Failed to load image: " << image_path << std::endl;
            return 1;
        }

        std::cout << "\nAnalyzing image..." << std::endl;

        // Analyze faces
        auto results = analyzer.analyze(image);

        std::cout << "Found " << results.size() << " face(s)\n" << std::endl;

        // Process each face
        for (size_t i = 0; i < results.size(); ++i) {
            const auto& result = results[i];

            std::cout << "Face " << (i + 1) << ":" << std::endl;
            std::cout << " BBox: [" << result.face.bbox.x << ", " << result.face.bbox.y << ", "
                      << result.face.bbox.width << ", " << result.face.bbox.height << "]"
                      << std::endl;
            std::cout << std::fixed << std::setprecision(3);
            std::cout << " Confidence: " << result.face.confidence << std::endl;

            // Draw bounding box
            cv::rectangle(image, result.face.bbox, cv::Scalar(0, 255, 0), 2);

            // Draw 5-point landmarks from detector
            for (const auto& pt : result.face.landmarks) {
                cv::circle(image, pt, 3, cv::Scalar(0, 0, 255), -1);
            }

            // If 106-point landmarks available
            if (result.landmarks) {
                std::cout << " Landmarks: 106 points detected" << std::endl;
                for (const auto& pt : result.landmarks->points) {
                    cv::circle(image, pt, 1, cv::Scalar(0, 255, 255), -1);
                }
            }

            // If embedding available
            if (result.embedding) {
                // Show first few values of embedding
                std::cout << " Embedding: [";
                for (size_t j = 0; j < 5; ++j) {
                    std::cout << (*result.embedding)[j];
                    if (j < 4)
                        std::cout << ", ";
                }
                std::cout << ", ... ] (512-dim)" << std::endl;
            }

            std::cout << std::endl;
        }

        // Save result
        cv::imwrite("analyzer_result.jpg", image);
        std::cout << "Saved result to analyzer_result.jpg" << std::endl;

    } catch (const cv::Exception& e) {
        std::cerr << "OpenCV Error: " << e.what() << std::endl;
        return 1;
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}
48
uniface-cpp/examples/detect.cpp
Normal file
@@ -0,0 +1,48 @@
#include <iostream>

#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>
#include <uniface/uniface.hpp>

int main(int argc, char** argv) {
    if (argc < 3) {
        std::cout << "Usage: " << argv[0] << " <model_path> <image_path>" << std::endl;
        return 1;
    }

    const std::string model_path = argv[1];
    const std::string image_path = argv[2];

    try {
        uniface::RetinaFace detector(model_path);

        cv::Mat image = cv::imread(image_path);
        if (image.empty()) {
            std::cerr << "Failed to load image: " << image_path << std::endl;
            return 1;
        }

        const auto faces = detector.detect(image);
        std::cout << "Detected " << faces.size() << " faces." << std::endl;

        // Draw results
        for (const auto& face : faces) {
            cv::rectangle(image, face.bbox, cv::Scalar(0, 255, 0), 2);
            for (const auto& pt : face.landmarks) {
                cv::circle(image, pt, 2, cv::Scalar(0, 0, 255), -1);
            }
        }

        cv::imwrite("result.jpg", image);
        std::cout << "Saved result to result.jpg" << std::endl;

    } catch (const cv::Exception& e) {
        std::cerr << "OpenCV Error: " << e.what() << std::endl;
        return 1;
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}
66
uniface-cpp/examples/landmarks.cpp
Normal file
@@ -0,0 +1,66 @@
#include <iostream>

#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>
#include <uniface/uniface.hpp>

int main(int argc, char** argv) {
    if (argc < 4) {
        std::cout << "Usage: " << argv[0] << " <detector_model> <landmark_model> <image_path>"
                  << std::endl;
        std::cout << "\nDetects 106-point facial landmarks and saves visualization." << std::endl;
        return 1;
    }

    const std::string detector_path = argv[1];
    const std::string landmark_path = argv[2];
    const std::string image_path = argv[3];

    try {
        // Load models
        uniface::RetinaFace detector(detector_path);
        uniface::Landmark106 landmarker(landmark_path);

        // Load image
        cv::Mat image = cv::imread(image_path);
        if (image.empty()) {
            std::cerr << "Failed to load image: " << image_path << std::endl;
            return 1;
        }

        // Detect faces
        auto faces = detector.detect(image);
        std::cout << "Detected " << faces.size() << " face(s)" << std::endl;

        // Process each face
        for (size_t i = 0; i < faces.size(); ++i) {
            const auto& face = faces[i];

            // Draw bounding box
            cv::rectangle(image, face.bbox, cv::Scalar(0, 255, 0), 2);

            // Get 106-point landmarks
            auto landmarks = landmarker.getLandmarks(image, face.bbox);

            // Draw all 106 points
            for (const auto& pt : landmarks.points) {
                cv::circle(image, pt, 1, cv::Scalar(0, 255, 255), -1);
            }

            std::cout << "Face " << (i + 1) << ": 106 landmarks detected" << std::endl;
        }

        // Save result
        cv::imwrite("landmarks_result.jpg", image);
        std::cout << "Saved result to landmarks_result.jpg" << std::endl;

    } catch (const cv::Exception& e) {
        std::cerr << "OpenCV Error: " << e.what() << std::endl;
        return 1;
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}
81
uniface-cpp/examples/recognize.cpp
Normal file
@@ -0,0 +1,81 @@
#include <iomanip>
#include <iostream>

#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>
#include <uniface/uniface.hpp>

int main(int argc, char** argv) {
    if (argc < 5) {
        std::cout << "Usage: " << argv[0]
                  << " <detector_model> <recognizer_model> <image1> <image2>" << std::endl;
        std::cout << "\nCompares faces from two images and outputs similarity score." << std::endl;
        return 1;
    }

    const std::string detector_path = argv[1];
    const std::string recognizer_path = argv[2];
    const std::string image1_path = argv[3];
    const std::string image2_path = argv[4];

    try {
        // Load models
        uniface::RetinaFace detector(detector_path);
        uniface::ArcFace recognizer(recognizer_path);

        // Load images
        cv::Mat image1 = cv::imread(image1_path);
        cv::Mat image2 = cv::imread(image2_path);

        if (image1.empty()) {
            std::cerr << "Failed to load image: " << image1_path << std::endl;
            return 1;
        }
        if (image2.empty()) {
            std::cerr << "Failed to load image: " << image2_path << std::endl;
            return 1;
        }

        // Detect faces
        auto faces1 = detector.detect(image1);
        auto faces2 = detector.detect(image2);

        if (faces1.empty()) {
            std::cerr << "No face detected in image1" << std::endl;
            return 1;
        }
        if (faces2.empty()) {
            std::cerr << "No face detected in image2" << std::endl;
            return 1;
        }

        std::cout << "Detected " << faces1.size() << " face(s) in image1" << std::endl;
        std::cout << "Detected " << faces2.size() << " face(s) in image2" << std::endl;

        // Get embeddings for first face in each image
        auto embedding1 = recognizer.getNormalizedEmbedding(image1, faces1[0].landmarks);
        auto embedding2 = recognizer.getNormalizedEmbedding(image2, faces2[0].landmarks);

        // Compute similarity
        float similarity = uniface::cosineSimilarity(embedding1, embedding2);

        std::cout << std::fixed << std::setprecision(4);
        std::cout << "\nCosine Similarity: " << similarity << std::endl;

        // Interpretation
        if (similarity > 0.4f) {
            std::cout << "Result: Same person (similarity > 0.4)" << std::endl;
        } else {
            std::cout << "Result: Different persons (similarity <= 0.4)" << std::endl;
        }

    } catch (const cv::Exception& e) {
        std::cerr << "OpenCV Error: " << e.what() << std::endl;
        return 1;
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}
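recognize.cpp calls a pair "same person" when the cosine similarity of the two embeddings exceeds 0.4. Because `getNormalizedEmbedding` returns unit-norm vectors, the similarity reduces to a plain dot product:

$$\mathrm{sim}(a, b) = \frac{a \cdot b}{\lVert a \rVert\,\lVert b \rVert} = \sum_{i=1}^{512} a_i b_i \qquad (\lVert a \rVert = \lVert b \rVert = 1)$$

The 0.4 cutoff is this example's default rather than a universal constant; in practice it should be tuned per model against a target false-accept rate.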
209
uniface-cpp/examples/webcam.cpp
Normal file
@@ -0,0 +1,209 @@
#include <cctype>   // std::isdigit (missing in the original listing; used below)
#include <chrono>
#include <cstring>  // strlen (missing in the original listing; used below)
#include <iostream>
#include <memory>

#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/videoio.hpp>
#include <uniface/uniface.hpp>

int main(int argc, char** argv) {
    if (argc < 2) {
        std::cout << "Usage: " << argv[0] << " <detector_model> [landmark_model] [camera_id]"
                  << std::endl;
        std::cout << "\nArguments:" << std::endl;
        std::cout << " detector_model : Path to face detector ONNX model (required)" << std::endl;
        std::cout << " landmark_model : Path to 106-point landmark ONNX model (optional)"
                  << std::endl;
        std::cout << " camera_id : Camera device ID, default 0 (optional)" << std::endl;
        std::cout << "\nExamples:" << std::endl;
        std::cout << " " << argv[0] << " detector.onnx" << std::endl;
        std::cout << " " << argv[0] << " detector.onnx landmark.onnx" << std::endl;
        std::cout << " " << argv[0] << " detector.onnx landmark.onnx 1" << std::endl;
        return 1;
    }

    const std::string detector_path = argv[1];
    std::string landmark_path;
    int camera_id = 0;

    // Parse arguments - landmark_model is optional
    if (argc >= 3) {
        // Check if argv[2] is a number (camera_id) or a path (landmark_model)
        if (std::isdigit(argv[2][0]) && strlen(argv[2]) <= 2) {
            camera_id = std::atoi(argv[2]);
        } else {
            landmark_path = argv[2];
            if (argc >= 4) {
                camera_id = std::atoi(argv[3]);
            }
        }
    }

    try {
        // Load detector
        std::cout << "Loading detector: " << detector_path << std::endl;
        uniface::RetinaFace detector(detector_path);
        std::cout << "Detector loaded!" << std::endl;

        // Load landmark model if provided
        std::unique_ptr<uniface::Landmark106> landmarker;
        if (!landmark_path.empty()) {
            std::cout << "Loading landmarker: " << landmark_path << std::endl;
            landmarker = std::make_unique<uniface::Landmark106>(landmark_path);
            std::cout << "Landmarker loaded!" << std::endl;
        }

        // Open camera
        cv::VideoCapture cap(camera_id);
        if (!cap.isOpened()) {
            std::cerr << "Error: Cannot open camera " << camera_id << std::endl;
            return 1;
        }

        const int frame_width = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_WIDTH));
        const int frame_height = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_HEIGHT));
        std::cout << "\nCamera opened: " << frame_width << "x" << frame_height << std::endl;
        std::cout << "Press 'q' to quit, 's' to save screenshot, 'l' to toggle landmarks"
                  << std::endl;

        cv::Mat frame;
        int frame_count = 0;
        double total_time = 0.0;
        bool show_landmarks = true;  // Toggle for 106-point landmarks

        while (true) {
            cap >> frame;
            if (frame.empty()) {
                std::cerr << "Error: Empty frame captured" << std::endl;
                break;
            }

            const auto start = std::chrono::high_resolution_clock::now();

            // Detect faces
            const auto faces = detector.detect(frame);

            // Get 106-point landmarks if available
            std::vector<uniface::Landmarks> all_landmarks;
            if (landmarker && show_landmarks) {
                all_landmarks.reserve(faces.size());
                for (const auto& face : faces) {
                    all_landmarks.push_back(landmarker->getLandmarks(frame, face.bbox));
                }
            }

            const auto end = std::chrono::high_resolution_clock::now();
            const std::chrono::duration<double, std::milli> elapsed = end - start;
            const double inference_time = elapsed.count();

            ++frame_count;
            total_time += inference_time;
            const double avg_time = total_time / static_cast<double>(frame_count);
            const double fps = 1000.0 / avg_time;

            // Draw results
            for (size_t i = 0; i < faces.size(); ++i) {
                const auto& face = faces[i];

                // Draw bounding box
                cv::rectangle(frame, face.bbox, cv::Scalar(0, 255, 0), 2);

                // Draw 5-point landmarks from detector
                for (size_t j = 0; j < face.landmarks.size(); ++j) {
                    cv::Scalar color;
                    if (j < 2) {
                        color = cv::Scalar(255, 0, 0);  // Eyes - Blue
                    } else if (j == 2) {
                        color = cv::Scalar(0, 255, 0);  // Nose - Green
                    } else {
                        color = cv::Scalar(0, 0, 255);  // Mouth - Red
                    }
                    cv::circle(frame, face.landmarks[j], 3, color, -1);
                }

                // Draw 106-point landmarks if available
                if (i < all_landmarks.size()) {
                    const auto& lm = all_landmarks[i];

                    // Draw all 106 points
                    for (const auto& pt : lm.points) {
                        cv::circle(frame, pt, 1, cv::Scalar(0, 255, 255), -1);
                    }
                }

                // Draw confidence
                const std::string conf_text = cv::format("%.2f", face.confidence);
                const cv::Point text_org(
                    static_cast<int>(face.bbox.x), static_cast<int>(face.bbox.y) - 5
                );
                cv::putText(
                    frame,
                    conf_text,
                    text_org,
                    cv::FONT_HERSHEY_SIMPLEX,
                    0.5,
                    cv::Scalar(0, 255, 0),
                    1
                );
            }

            // Draw info overlay
            std::string mode = landmarker
                ? (show_landmarks ? "Detection + 106 Landmarks" : "Detection Only")
                : "Detection Only";
            const std::string info_text = cv::format(
                "FPS: %.1f | Faces: %zu | Time: %.1fms", fps, faces.size(), inference_time
            );
            cv::putText(
                frame,
                info_text,
                cv::Point(10, 30),
                cv::FONT_HERSHEY_SIMPLEX,
                0.7,
                cv::Scalar(0, 255, 0),
                2
            );
            cv::putText(
                frame,
                mode,
                cv::Point(10, 60),
                cv::FONT_HERSHEY_SIMPLEX,
                0.6,
                cv::Scalar(255, 255, 0),
                2
            );

            cv::imshow("Uniface - Face Detection & Landmarks", frame);

            const char key = static_cast<char>(cv::waitKey(1));
            if (key == 'q' || key == 27) {
                break;
            } else if (key == 's') {
                const std::string filename = cv::format("screenshot_%d.jpg", frame_count);
                cv::imwrite(filename, frame);
                std::cout << "Screenshot saved: " << filename << std::endl;
            } else if (key == 'l' && landmarker) {
                show_landmarks = !show_landmarks;
                std::cout << "106-point landmarks: " << (show_landmarks ? "ON" : "OFF")
                          << std::endl;
            }
        }

        cap.release();
        cv::destroyAllWindows();

        std::cout << "\n=== Statistics ===" << std::endl;
        std::cout << "Total frames: " << frame_count << std::endl;
        std::cout << "Average inference time: " << (total_time / frame_count) << " ms" << std::endl;

    } catch (const cv::Exception& e) {
        std::cerr << "OpenCV Error: " << e.what() << std::endl;
        return 1;
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}
63
uniface-cpp/include/uniface/analyzer.hpp
Normal file
@@ -0,0 +1,63 @@
#ifndef UNIFACE_ANALYZER_HPP_
#define UNIFACE_ANALYZER_HPP_

#include "uniface/detector.hpp"
#include "uniface/landmarker.hpp"
#include "uniface/recognizer.hpp"
#include "uniface/types.hpp"

#include <memory>
#include <optional>
#include <string>
#include <vector>

namespace uniface {

// Result of face analysis
struct AnalyzedFace {
    Face face;                           // detection result (bbox, confidence, 5-point landmarks)
    std::optional<Landmarks> landmarks;  // 106-point landmarks (if landmarker loaded)
    std::optional<Embedding> embedding;  // face embedding (if recognizer loaded)
};

// Unified face analysis combining detection, recognition, and landmarks
class FaceAnalyzer {
public:
    FaceAnalyzer() = default;
    ~FaceAnalyzer() = default;

    FaceAnalyzer(const FaceAnalyzer&) = delete;
    FaceAnalyzer& operator=(const FaceAnalyzer&) = delete;
    FaceAnalyzer(FaceAnalyzer&&) = default;
    FaceAnalyzer& operator=(FaceAnalyzer&&) = default;

    // Load components (returns *this for chaining)
    FaceAnalyzer& loadDetector(const std::string& path, const DetectorConfig& config = DetectorConfig{});
    FaceAnalyzer& loadRecognizer(const std::string& path, const RecognizerConfig& config = RecognizerConfig{});
    FaceAnalyzer& loadLandmarker(const std::string& path, const LandmarkerConfig& config = LandmarkerConfig{});

    // Analyze faces in BGR image (throws if detector not loaded)
    [[nodiscard]] std::vector<AnalyzedFace> analyze(const cv::Mat& image);

    // Component checks
    [[nodiscard]] bool hasDetector() const noexcept { return detector_ != nullptr; }
    [[nodiscard]] bool hasRecognizer() const noexcept { return recognizer_ != nullptr; }
    [[nodiscard]] bool hasLandmarker() const noexcept { return landmarker_ != nullptr; }

    // Direct component access
    [[nodiscard]] RetinaFace* detector() noexcept { return detector_.get(); }
    [[nodiscard]] ArcFace* recognizer() noexcept { return recognizer_.get(); }
    [[nodiscard]] Landmark106* landmarker() noexcept { return landmarker_.get(); }
    [[nodiscard]] const RetinaFace* detector() const noexcept { return detector_.get(); }
    [[nodiscard]] const ArcFace* recognizer() const noexcept { return recognizer_.get(); }
    [[nodiscard]] const Landmark106* landmarker() const noexcept { return landmarker_.get(); }

private:
    std::unique_ptr<RetinaFace> detector_;
    std::unique_ptr<ArcFace> recognizer_;
    std::unique_ptr<Landmark106> landmarker_;
};

}  // namespace uniface

#endif  // UNIFACE_ANALYZER_HPP_
47
uniface-cpp/include/uniface/detector.hpp
Normal file
@@ -0,0 +1,47 @@
#ifndef UNIFACE_DETECTOR_HPP_
#define UNIFACE_DETECTOR_HPP_

#include "uniface/types.hpp"

#include <array>
#include <string>
#include <vector>

#include <opencv2/dnn.hpp>

namespace uniface {

// RetinaFace detector using OpenCV DNN backend
class RetinaFace {
public:
    explicit RetinaFace(
        const std::string& model_path,
        float conf_thresh = 0.5f,
        float nms_thresh = 0.4f,
        cv::Size input_size = cv::Size(640, 640)
    );

    // Detect faces in BGR image, returns bboxes + 5-point landmarks
    [[nodiscard]] std::vector<Face> detect(const cv::Mat& image);

    // Accessors
    [[nodiscard]] float getConfidenceThreshold() const noexcept { return confidence_threshold_; }
    [[nodiscard]] float getNmsThreshold() const noexcept { return nms_threshold_; }
    [[nodiscard]] cv::Size getInputSize() const noexcept { return input_size_; }

    void setConfidenceThreshold(float threshold) noexcept { confidence_threshold_ = threshold; }
    void setNmsThreshold(float threshold) noexcept { nms_threshold_ = threshold; }

private:
    cv::dnn::Net net_;
    float confidence_threshold_;
    float nms_threshold_;
    cv::Size input_size_;
    std::vector<std::array<float, 4>> anchors_;

    void generateAnchors();
};

}  // namespace uniface

#endif  // UNIFACE_DETECTOR_HPP_
32
uniface-cpp/include/uniface/landmarker.hpp
Normal file
@@ -0,0 +1,32 @@
#ifndef UNIFACE_LANDMARKER_HPP_
#define UNIFACE_LANDMARKER_HPP_

#include "uniface/types.hpp"

#include <string>

#include <opencv2/dnn.hpp>

namespace uniface {

// 106-point facial landmark detector
class Landmark106 {
public:
    explicit Landmark106(const std::string& model_path, const LandmarkerConfig& config = LandmarkerConfig{});

    // Detect 106 landmarks for a face, returns points in original image coordinates
    [[nodiscard]] Landmarks getLandmarks(const cv::Mat& image, const cv::Rect2f& bbox);

    [[nodiscard]] cv::Size getInputSize() const noexcept { return config_.input_size; }

private:
    cv::dnn::Net net_;
    LandmarkerConfig config_;

    [[nodiscard]] cv::Mat preprocess(const cv::Mat& image, const cv::Rect2f& bbox, cv::Mat& transform);
    [[nodiscard]] Landmarks postprocess(const cv::Mat& predictions, const cv::Mat& transform);
};

}  // namespace uniface

#endif  // UNIFACE_LANDMARKER_HPP_
37
uniface-cpp/include/uniface/recognizer.hpp
Normal file
@@ -0,0 +1,37 @@
#ifndef UNIFACE_RECOGNIZER_HPP_
#define UNIFACE_RECOGNIZER_HPP_

#include "uniface/types.hpp"

#include <string>

#include <opencv2/dnn.hpp>

namespace uniface {

// ArcFace face recognition (MobileNet/ResNet backbones)
class ArcFace {
public:
    explicit ArcFace(const std::string& model_path, const RecognizerConfig& config = RecognizerConfig{});

    // Get 512-dim embedding from pre-aligned 112x112 face
    [[nodiscard]] Embedding getEmbedding(const cv::Mat& aligned_face);

    // Get 512-dim embedding with automatic alignment
    [[nodiscard]] Embedding getEmbedding(const cv::Mat& image, const std::array<cv::Point2f, 5>& landmarks);

    // Get L2-normalized embedding with automatic alignment
    [[nodiscard]] Embedding getNormalizedEmbedding(const cv::Mat& image, const std::array<cv::Point2f, 5>& landmarks);

    [[nodiscard]] cv::Size getInputSize() const noexcept { return config_.input_size; }

private:
    cv::dnn::Net net_;
    RecognizerConfig config_;

    [[nodiscard]] cv::Mat preprocess(const cv::Mat& face_image);
};

}  // namespace uniface

#endif  // UNIFACE_RECOGNIZER_HPP_
45
uniface-cpp/include/uniface/types.hpp
Normal file
@@ -0,0 +1,45 @@
#ifndef UNIFACE_TYPES_HPP_
#define UNIFACE_TYPES_HPP_

#include <array>
#include <vector>

#include <opencv2/core.hpp>

namespace uniface {

// Detected face with bbox, confidence, and 5-point landmarks
struct Face {
    cv::Rect2f bbox;
    float confidence;
    std::array<cv::Point2f, 5> landmarks;  // left_eye, right_eye, nose, left_mouth, right_mouth
};

// 512-dimensional face embedding
using Embedding = std::array<float, 512>;

// 106-point facial landmarks
struct Landmarks {
    std::array<cv::Point2f, 106> points;
};

// Configuration structs
struct DetectorConfig {
    float conf_thresh = 0.5f;
    float nms_thresh = 0.4f;
    cv::Size input_size = cv::Size(640, 640);
};

struct RecognizerConfig {
    float input_mean = 127.5f;
    float input_std = 127.5f;
    cv::Size input_size = cv::Size(112, 112);
};

struct LandmarkerConfig {
    cv::Size input_size = cv::Size(192, 192);
};

}  // namespace uniface

#endif  // UNIFACE_TYPES_HPP_
11
uniface-cpp/include/uniface/uniface.hpp
Normal file
@@ -0,0 +1,11 @@
#ifndef UNIFACE_HPP_
#define UNIFACE_HPP_

#include "uniface/analyzer.hpp"
#include "uniface/detector.hpp"
#include "uniface/landmarker.hpp"
#include "uniface/recognizer.hpp"
#include "uniface/types.hpp"
#include "uniface/utils.hpp"

#endif  // UNIFACE_HPP_
58
uniface-cpp/include/uniface/utils.hpp
Normal file
@@ -0,0 +1,58 @@
#ifndef UNIFACE_UTILS_HPP_
#define UNIFACE_UTILS_HPP_

#include "uniface/types.hpp"

#include <array>
#include <cmath>

#include <opencv2/core.hpp>

namespace uniface {

// Reference 5-point landmarks for ArcFace alignment (112x112)
inline constexpr std::array<float, 10> kReferenceAlignment = {
    38.2946f, 51.6963f,  // left eye
    73.5318f, 51.5014f,  // right eye
    56.0252f, 71.7366f,  // nose
    41.5493f, 92.3655f,  // left mouth
    70.7299f, 92.2041f   // right mouth
};

// Align face using 5-point landmarks (default 112x112 for ArcFace)
[[nodiscard]] cv::Mat alignFace(
    const cv::Mat& image,
    const std::array<cv::Point2f, 5>& landmarks,
    cv::Size output_size = cv::Size(112, 112)
);

// Cosine similarity between embeddings, returns [-1, 1]
[[nodiscard]] float cosineSimilarity(const Embedding& a, const Embedding& b) noexcept;

// Apply 2x3 affine transform to points
template <size_t N>
[[nodiscard]] std::array<cv::Point2f, N> transformPoints2D(
    const std::array<cv::Point2f, N>& points, const cv::Mat& transform
) {
    std::array<cv::Point2f, N> result{};
    for (size_t i = 0; i < N; ++i) {
        const float x = points[i].x;
        const float y = points[i].y;
        result[i].x = static_cast<float>(
            transform.at<double>(0, 0) * x + transform.at<double>(0, 1) * y +
            transform.at<double>(0, 2)
        );
        result[i].y = static_cast<float>(
            transform.at<double>(1, 0) * x + transform.at<double>(1, 1) * y +
            transform.at<double>(1, 2)
        );
    }
    return result;
}

// Letterbox resize preserving aspect ratio, returns scale factor
[[nodiscard]] float letterboxResize(const cv::Mat& src, cv::Mat& dst, cv::Size target_size);

}  // namespace uniface

#endif  // UNIFACE_UTILS_HPP_
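utils.hpp fixes the two numeric conventions the C++ side shares with Python: the ArcFace 112x112 five-point template and cosine similarity on unit vectors. Below is a NumPy sketch of both for offline checking; `cv2.estimateAffinePartial2D` is one common way to fit the similarity transform and is an assumption here, not necessarily what `alignFace` does internally:

```python
import cv2
import numpy as np

# Same values as kReferenceAlignment (ArcFace 112x112 template).
REFERENCE = np.array([
    [38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
    [41.5493, 92.3655], [70.7299, 92.2041],
], dtype=np.float32)


def align_face(image: np.ndarray, landmarks: np.ndarray) -> np.ndarray:
    """Warp a face crop to 112x112 via a similarity transform fit to 5 points."""
    matrix, _ = cv2.estimateAffinePartial2D(landmarks.astype(np.float32), REFERENCE)
    return cv2.warpAffine(image, matrix, (112, 112))


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity in [-1, 1]; a plain dot product for unit vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
```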
0
uniface-cpp/models/.gitkeep
Normal file
55
uniface-cpp/src/analyzer.cpp
Normal file
@@ -0,0 +1,55 @@
#include "uniface/analyzer.hpp"

#include <stdexcept>

namespace uniface {

FaceAnalyzer& FaceAnalyzer::loadDetector(const std::string& path, const DetectorConfig& config) {
    detector_ = std::make_unique<RetinaFace>(
        path, config.conf_thresh, config.nms_thresh, config.input_size
    );
    return *this;
}

FaceAnalyzer& FaceAnalyzer::loadRecognizer(
    const std::string& path, const RecognizerConfig& config
) {
    recognizer_ = std::make_unique<ArcFace>(path, config);
    return *this;
}

FaceAnalyzer& FaceAnalyzer::loadLandmarker(
    const std::string& path, const LandmarkerConfig& config
) {
    landmarker_ = std::make_unique<Landmark106>(path, config);
    return *this;
}

std::vector<AnalyzedFace> FaceAnalyzer::analyze(const cv::Mat& image) {
    if (!detector_) {
        throw std::runtime_error("FaceAnalyzer: detector not loaded. Call loadDetector() first.");
    }

    auto faces = detector_->detect(image);

    std::vector<AnalyzedFace> results;
    results.reserve(faces.size());

    for (const auto& face : faces) {
        AnalyzedFace result;
        result.face = face;

        if (landmarker_) {
            result.landmarks = landmarker_->getLandmarks(image, face.bbox);
        }
        if (recognizer_) {
            result.embedding = recognizer_->getNormalizedEmbedding(image, face.landmarks);
        }

        results.push_back(std::move(result));
    }

    return results;
}

}  // namespace uniface
204
uniface-cpp/src/detector.cpp
Normal file
@@ -0,0 +1,204 @@
#include "uniface/detector.hpp"

#include "uniface/utils.hpp"

#include <cmath>
#include <iostream>

#include <opencv2/imgproc.hpp>

namespace uniface {

namespace {

// Model configuration constants
constexpr std::array<int, 3> kFeatureStrides = {8, 16, 32};
constexpr std::array<float, 2> kVariance = {0.1f, 0.2f};
constexpr int kNumLandmarks = 5;

// BGR mean values for image normalization
constexpr float kMeanB = 104.0f;
constexpr float kMeanG = 117.0f;
constexpr float kMeanR = 123.0f;

// Anchor min sizes for each feature map level
const std::vector<std::vector<int>> kMinSizes = {
    { 16,  32},
    { 64, 128},
    {256, 512}
};

}  // namespace

RetinaFace::RetinaFace(
    const std::string& model_path, float conf_thresh, float nms_thresh, cv::Size input_size
)
    : net_(cv::dnn::readNetFromONNX(model_path))
    , confidence_threshold_(conf_thresh)
    , nms_threshold_(nms_thresh)
    , input_size_(input_size) {
    generateAnchors();
}

void RetinaFace::generateAnchors() {
    anchors_.clear();

    size_t estimated_anchors = 0;
    for (size_t k = 0; k < kFeatureStrides.size(); ++k) {
        const int step = kFeatureStrides[k];
        const auto feature_h = static_cast<size_t>(
            std::ceil(static_cast<float>(input_size_.height) / static_cast<float>(step))
        );
        const auto feature_w = static_cast<size_t>(
            std::ceil(static_cast<float>(input_size_.width) / static_cast<float>(step))
        );
        estimated_anchors += feature_h * feature_w * kMinSizes[k].size();
    }
    anchors_.reserve(estimated_anchors);

    for (size_t k = 0; k < kFeatureStrides.size(); ++k) {
        const int step = kFeatureStrides[k];
        const int feature_h = static_cast<int>(
            std::ceil(static_cast<float>(input_size_.height) / static_cast<float>(step))
        );
        const int feature_w = static_cast<int>(
            std::ceil(static_cast<float>(input_size_.width) / static_cast<float>(step))
        );

        for (int i = 0; i < feature_h; ++i) {
            for (int j = 0; j < feature_w; ++j) {
                for (const int min_size : kMinSizes[k]) {
                    const float s_kx = static_cast<float>(min_size) /
                                       static_cast<float>(input_size_.height);
                    const float s_ky = static_cast<float>(min_size) /
                                       static_cast<float>(input_size_.width);
                    const float cx = (static_cast<float>(j) + 0.5f) * static_cast<float>(step) /
                                     static_cast<float>(input_size_.height);
                    const float cy = (static_cast<float>(i) + 0.5f) * static_cast<float>(step) /
                                     static_cast<float>(input_size_.width);

                    anchors_.push_back({cx, cy, s_kx, s_ky});
                }
            }
        }
    }
}
|
||||
std::vector<Face> RetinaFace::detect(const cv::Mat& image) {
|
||||
cv::Mat input_blob;
|
||||
const float resize_factor = letterboxResize(image, input_blob, input_size_);
|
||||
|
||||
const cv::Mat blob = cv::dnn::blobFromImage(
|
||||
input_blob, 1.0, cv::Size(), cv::Scalar(kMeanB, kMeanG, kMeanR), false, false
|
||||
);
|
||||
|
||||
net_.setInput(blob);
|
||||
const auto output_names = net_.getUnconnectedOutLayersNames();
|
||||
std::vector<cv::Mat> outputs;
|
||||
net_.forward(outputs, output_names);
|
||||
|
||||
if (outputs.size() < 3) {
|
||||
std::cerr << "Error: Model output count mismatch. Expected at least 3, got "
|
||||
<< outputs.size() << std::endl;
|
||||
return {};
|
||||
}
|
||||
|
||||
// Identify outputs by shape: loc(N,4), conf(N,2), landmarks(N,10)
|
||||
cv::Mat loc_output, conf_output, land_output;
|
||||
|
||||
for (const auto& output : outputs) {
|
||||
switch (output.size[2]) {
|
||||
case 4: loc_output = output; break;
|
||||
case 2: conf_output = output; break;
|
||||
case 10: land_output = output; break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to positional outputs
|
||||
if (loc_output.empty()) loc_output = outputs[0];
|
||||
if (conf_output.empty()) conf_output = outputs[1];
|
||||
if (land_output.empty()) land_output = outputs[2];
|
||||
|
||||
const auto* loc_data = reinterpret_cast<const float*>(loc_output.data);
|
||||
const auto* conf_data = reinterpret_cast<const float*>(conf_output.data);
|
||||
const auto* land_data = reinterpret_cast<const float*>(land_output.data);
|
||||
const auto num_priors = static_cast<size_t>(loc_output.size[1]);
|
||||
|
||||
if (num_priors != anchors_.size()) {
|
||||
std::cerr << "Error: Anchor count mismatch! Expected " << anchors_.size()
|
||||
<< " anchors but model output has " << num_priors << " priors.\n"
|
||||
<< "This usually means the input size doesn't match the model's "
|
||||
<< "expected size." << std::endl;
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<cv::Rect2f> decoded_boxes;
|
||||
std::vector<float> scores;
|
||||
std::vector<std::array<cv::Point2f, 5>> decoded_landmarks;
|
||||
decoded_boxes.reserve(num_priors);
|
||||
scores.reserve(num_priors);
|
||||
decoded_landmarks.reserve(num_priors);
|
||||
|
||||
const auto scale_w = static_cast<float>(input_size_.width);
|
||||
const auto scale_h = static_cast<float>(input_size_.height);
|
||||
|
||||
for (size_t i = 0; i < num_priors; ++i) {
|
||||
const float score = conf_data[i * 2 + 1];
|
||||
if (score < confidence_threshold_) continue;
|
||||
|
||||
const float px = anchors_[i][0];
|
||||
const float py = anchors_[i][1];
|
||||
const float pw = anchors_[i][2];
|
||||
const float ph = anchors_[i][3];
|
||||
|
||||
const float dx = loc_data[i * 4 + 0];
|
||||
const float dy = loc_data[i * 4 + 1];
|
||||
const float dw = loc_data[i * 4 + 2];
|
||||
const float dh = loc_data[i * 4 + 3];
|
||||
|
||||
const float cx = px + dx * kVariance[0] * pw;
|
||||
const float cy = py + dy * kVariance[0] * ph;
|
||||
const float w = pw * std::exp(dw * kVariance[1]);
|
||||
const float h = ph * std::exp(dh * kVariance[1]);
|
||||
|
||||
const float x1 = (cx - w / 2.0f) * scale_w / resize_factor;
|
||||
const float y1 = (cy - h / 2.0f) * scale_h / resize_factor;
|
||||
const float x2 = (cx + w / 2.0f) * scale_w / resize_factor;
|
||||
const float y2 = (cy + h / 2.0f) * scale_h / resize_factor;
|
||||
|
||||
decoded_boxes.emplace_back(x1, y1, x2 - x1, y2 - y1);
|
||||
scores.push_back(score);
|
||||
|
||||
std::array<cv::Point2f, 5> landmarks{};
|
||||
for (int k = 0; k < kNumLandmarks; ++k) {
|
||||
const float ldx = land_data[i * 10 + static_cast<size_t>(k) * 2 + 0];
|
||||
const float ldy = land_data[i * 10 + static_cast<size_t>(k) * 2 + 1];
|
||||
const float lx = (px + ldx * kVariance[0] * pw) * scale_w / resize_factor;
|
||||
const float ly = (py + ldy * kVariance[0] * ph) * scale_h / resize_factor;
|
||||
landmarks[static_cast<size_t>(k)] = cv::Point2f(lx, ly);
|
||||
}
|
||||
decoded_landmarks.push_back(landmarks);
|
||||
}
|
||||
|
||||
// NMS
|
||||
std::vector<cv::Rect2d> boxes_for_nms;
|
||||
boxes_for_nms.reserve(decoded_boxes.size());
|
||||
for (const auto& box : decoded_boxes) {
|
||||
boxes_for_nms.emplace_back(box.x, box.y, box.width, box.height);
|
||||
}
|
||||
|
||||
std::vector<int> nms_indices;
|
||||
cv::dnn::NMSBoxes(boxes_for_nms, scores, confidence_threshold_, nms_threshold_, nms_indices);
|
||||
|
||||
std::vector<Face> results;
|
||||
results.reserve(nms_indices.size());
|
||||
for (const int idx : nms_indices) {
|
||||
const auto uidx = static_cast<size_t>(idx);
|
||||
results.push_back({decoded_boxes[uidx], scores[uidx], decoded_landmarks[uidx]});
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
} // namespace uniface
|
||||
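For reference, the decode loop above is the standard SSD/RetinaFace prior-box parameterisation with variances $v_0 = 0.1$ and $v_1 = 0.2$:

$$
c_x = p_x + d_x\,v_0\,p_w, \qquad c_y = p_y + d_y\,v_0\,p_h, \qquad w = p_w\,e^{d_w v_1}, \qquad h = p_h\,e^{d_h v_1}
$$

Box corners are then scaled by the input size and divided by the letterbox `resize_factor` to land in original-image pixels; the five landmark points reuse the same center-offset formula with $v_0$.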
90
uniface-cpp/src/landmarker.cpp
Normal file
@@ -0,0 +1,90 @@
#include "uniface/landmarker.hpp"

#include <cmath>

#include <opencv2/imgproc.hpp>

namespace uniface {

namespace {

constexpr int kNumLandmarks = 106;

cv::Mat computeCenterTransform(const cv::Point2f& center, float scale, int output_size) {
    cv::Mat transform = cv::Mat::zeros(2, 3, CV_64F);

    transform.at<double>(0, 0) = scale;
    transform.at<double>(1, 1) = scale;
    transform.at<double>(0, 2) = -center.x * scale + output_size / 2.0;
    transform.at<double>(1, 2) = -center.y * scale + output_size / 2.0;

    return transform;
}

} // namespace

Landmark106::Landmark106(const std::string& model_path, const LandmarkerConfig& config)
    : net_(cv::dnn::readNetFromONNX(model_path))
    , config_(config) {}

cv::Mat Landmark106::preprocess(const cv::Mat& image, const cv::Rect2f& bbox, cv::Mat& transform) {
    const float width = bbox.width;
    const float height = bbox.height;
    const float center_x = bbox.x + width / 2.0f;
    const float center_y = bbox.y + height / 2.0f;

    const float max_dim = std::max(width, height);
    const float scale = static_cast<float>(config_.input_size.width) / (max_dim * 1.5f);

    transform = computeCenterTransform(cv::Point2f(center_x, center_y), scale, config_.input_size.width);

    cv::Mat aligned;
    cv::warpAffine(image, aligned, transform, config_.input_size, cv::INTER_LINEAR, cv::BORDER_CONSTANT);

    cv::Mat blob = cv::dnn::blobFromImage(aligned, 1.0, config_.input_size, cv::Scalar(0, 0, 0), true, false);

    return blob;
}

Landmarks Landmark106::postprocess(const cv::Mat& predictions, const cv::Mat& transform) {
    Landmarks result{};
    const auto* pred_data = reinterpret_cast<const float*>(predictions.data);

    cv::Mat inverse_transform;
    cv::invertAffineTransform(transform, inverse_transform);

    const int input_size = config_.input_size.width;
    const float half_size = static_cast<float>(input_size) / 2.0f;

    for (int i = 0; i < kNumLandmarks; ++i) {
        // Denormalize from [-1, 1] to pixel coordinates
        const float x = (pred_data[i * 2 + 0] + 1.0f) * half_size;
        const float y = (pred_data[i * 2 + 1] + 1.0f) * half_size;

        // Transform back to original image coordinates
        const float orig_x = static_cast<float>(
            inverse_transform.at<double>(0, 0) * x + inverse_transform.at<double>(0, 1) * y +
            inverse_transform.at<double>(0, 2)
        );
        const float orig_y = static_cast<float>(
            inverse_transform.at<double>(1, 0) * x + inverse_transform.at<double>(1, 1) * y +
            inverse_transform.at<double>(1, 2)
        );

        result.points[static_cast<size_t>(i)] = cv::Point2f(orig_x, orig_y);
    }

    return result;
}

Landmarks Landmark106::getLandmarks(const cv::Mat& image, const cv::Rect2f& bbox) {
    cv::Mat transform;
    cv::Mat blob = preprocess(image, bbox, transform);

    net_.setInput(blob);
    cv::Mat output = net_.forward();

    return postprocess(output, transform);
}

} // namespace uniface
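The crop transform above maps a padded square around the box into the network input of side $S$ (`config_.input_size.width`), and `postprocess` undoes it:

$$
s = \frac{S}{1.5\,\max(w_{\text{box}}, h_{\text{box}})}, \qquad x' = s\,(x - c_x) + \frac{S}{2}, \qquad y' = s\,(y - c_y) + \frac{S}{2}
$$

Predictions arrive normalised to $[-1, 1]$, so each point is first denormalised with $x' = (\hat{x} + 1)\,S/2$ and then pushed through the inverted affine to recover original-image coordinates.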
73
uniface-cpp/src/recognizer.cpp
Normal file
@@ -0,0 +1,73 @@
#include "uniface/recognizer.hpp"

#include "uniface/utils.hpp"

#include <algorithm>
#include <cmath>

#include <opencv2/imgproc.hpp>

namespace uniface {

ArcFace::ArcFace(const std::string& model_path, const RecognizerConfig& config)
    : net_(cv::dnn::readNetFromONNX(model_path))
    , config_(config) {}

cv::Mat ArcFace::preprocess(const cv::Mat& face_image) {
    cv::Mat resized;
    if (face_image.size() != config_.input_size) {
        cv::resize(face_image, resized, config_.input_size);
    } else {
        resized = face_image;
    }

    // Normalize: (pixel - mean) / std, BGR -> RGB
    cv::Mat blob = cv::dnn::blobFromImage(
        resized, 1.0 / config_.input_std, config_.input_size,
        cv::Scalar(config_.input_mean, config_.input_mean, config_.input_mean), true, false
    );

    return blob;
}

Embedding ArcFace::getEmbedding(const cv::Mat& aligned_face) {
    cv::Mat blob = preprocess(aligned_face);

    net_.setInput(blob);
    cv::Mat output = net_.forward();

    Embedding embedding{};
    const auto* output_data = reinterpret_cast<const float*>(output.data);
    const size_t embedding_size = std::min(static_cast<size_t>(output.total()), embedding.size());

    for (size_t i = 0; i < embedding_size; ++i) {
        embedding[i] = output_data[i];
    }

    return embedding;
}

Embedding ArcFace::getEmbedding(const cv::Mat& image, const std::array<cv::Point2f, 5>& landmarks) {
    cv::Mat aligned = alignFace(image, landmarks, config_.input_size);
    return getEmbedding(aligned);
}

Embedding ArcFace::getNormalizedEmbedding(const cv::Mat& image, const std::array<cv::Point2f, 5>& landmarks) {
    Embedding embedding = getEmbedding(image, landmarks);

    // L2 normalize
    float norm = 0.0f;
    for (const float val : embedding) {
        norm += val * val;
    }
    norm = std::sqrt(norm);

    if (norm > 1e-8f) {
        for (float& val : embedding) {
            val /= norm;
        }
    }

    return embedding;
}

} // namespace uniface
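A short verification sketch pairing `ArcFace::getNormalizedEmbedding` with `cosineSimilarity` from `utils.cpp`. The model paths, the default-constructed `RecognizerConfig`, and reading landmarks off the first detection are assumptions for illustration:

```cpp
// Sketch: compare two face photos of the same person. Assumes RecognizerConfig
// has usable defaults and that Face exposes the `landmarks` field used above.
#include "uniface/detector.hpp"
#include "uniface/recognizer.hpp"
#include "uniface/utils.hpp"

#include <iostream>
#include <opencv2/imgcodecs.hpp>

int main() {
    uniface::RetinaFace detector("models/retinaface_mv2.onnx", 0.5f, 0.4f, {640, 640});
    uniface::ArcFace recognizer("models/w600k_mbf.onnx", uniface::RecognizerConfig{});

    const cv::Mat a = cv::imread("alice_1.jpg");  // hypothetical inputs
    const cv::Mat b = cv::imread("alice_2.jpg");

    const auto faces_a = detector.detect(a);
    const auto faces_b = detector.detect(b);
    if (faces_a.empty() || faces_b.empty()) return 1;

    const auto ea = recognizer.getNormalizedEmbedding(a, faces_a[0].landmarks);
    const auto eb = recognizer.getNormalizedEmbedding(b, faces_b[0].landmarks);

    // Both embeddings are unit-length, so this equals their dot product.
    std::cout << "similarity: " << uniface::cosineSimilarity(ea, eb) << '\n';
    return 0;
}
```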
82
uniface-cpp/src/utils.cpp
Normal file
@@ -0,0 +1,82 @@
#include "uniface/utils.hpp"

#include <cmath>

#include <opencv2/calib3d.hpp>
#include <opencv2/imgproc.hpp>

namespace uniface {

cv::Mat alignFace(const cv::Mat& image, const std::array<cv::Point2f, 5>& landmarks, cv::Size output_size) {
    const float ratio = static_cast<float>(output_size.width) / 112.0f;

    std::vector<cv::Point2f> dst_points(5);
    for (int i = 0; i < 5; ++i) {
        dst_points[i].x = kReferenceAlignment[static_cast<size_t>(i) * 2] * ratio;
        dst_points[i].y = kReferenceAlignment[static_cast<size_t>(i) * 2 + 1] * ratio;
    }

    std::vector<cv::Point2f> src_points(landmarks.begin(), landmarks.end());
    cv::Mat transform = cv::estimateAffinePartial2D(src_points, dst_points);

    if (transform.empty()) {
        // Degenerate landmarks: fall back to a plain resize.
        cv::Mat resized;
        cv::resize(image, resized, output_size);
        return resized;
    }

    cv::Mat aligned;
    cv::warpAffine(image, aligned, transform, output_size, cv::INTER_LINEAR, cv::BORDER_CONSTANT);

    return aligned;
}

float cosineSimilarity(const Embedding& a, const Embedding& b) noexcept {
    float dot = 0.0f;
    float norm_a = 0.0f;
    float norm_b = 0.0f;

    for (size_t i = 0; i < a.size(); ++i) {
        dot += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }

    const float denom = std::sqrt(norm_a) * std::sqrt(norm_b);
    if (denom < 1e-8f) {
        return 0.0f;
    }

    return dot / denom;
}

float letterboxResize(const cv::Mat& src, cv::Mat& dst, cv::Size target_size) {
    const auto src_height = static_cast<float>(src.rows);
    const auto src_width = static_cast<float>(src.cols);
    const auto target_height = static_cast<float>(target_size.height);
    const auto target_width = static_cast<float>(target_size.width);

    const float im_ratio = src_height / src_width;
    const float model_ratio = target_height / target_width;

    int new_width = 0;
    int new_height = 0;

    if (im_ratio > model_ratio) {
        new_height = static_cast<int>(target_height);
        new_width = static_cast<int>(static_cast<float>(new_height) / im_ratio);
    } else {
        new_width = static_cast<int>(target_width);
        new_height = static_cast<int>(static_cast<float>(new_width) * im_ratio);
    }

    // Aspect ratio is preserved, so height and width scale factors are equal.
    const float resize_factor = static_cast<float>(new_height) / src_height;

    cv::Mat resized;
    cv::resize(src, resized, cv::Size(new_width, new_height));

    // Zero-pad to the target size; the image sits at the top-left corner.
    dst = cv::Mat::zeros(target_size, src.type());
    resized.copyTo(dst(cv::Rect(0, 0, new_width, new_height)));

    return resize_factor;
}

} // namespace uniface
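Since the letterbox preserves aspect ratio and pads only to the right and bottom, the returned factor is the single uniform scale $r = h_{\text{new}}/h_{\text{src}} = w_{\text{new}}/w_{\text{src}}$, which is why the detector can map any model-space coordinate back with $x_{\text{src}} = x_{\text{model}} / r$ and no offset correction.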
@@ -29,7 +29,9 @@ from __future__ import annotations

__license__ = 'MIT'
__author__ = 'Yakhyokhuja Valikhujaev'
__version__ = '3.0.0'
__version__ = '3.1.0'

import contextlib

from uniface.face_utils import compute_similarity, face_alignment
from uniface.log import Logger, enable_logging
@@ -54,6 +56,10 @@ from .spoofing import MiniFASNet, create_spoofer
from .tracking import BYTETracker
from .types import AttributeResult, EmotionResult, Face, GazeResult, SpoofingResult

# Optional: FAISS vector store (requires `pip install faiss-cpu`)
with contextlib.suppress(ImportError):
    from .indexing import FAISS

__all__ = [
    # Metadata
    '__author__',
@@ -101,6 +107,8 @@ __all__ = [
    'BYTETracker',
    # Privacy
    'BlurFace',
    # Indexing (optional)
    'FAISS',
    # Utilities
    'Logger',
    'compute_similarity',

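Because the import is guarded, `FAISS` is simply absent from the package namespace when faiss is not installed, so callers can feature-detect it. A sketch, assuming nothing beyond the guarded import above:

```python
import uniface

if hasattr(uniface, 'FAISS'):
    store = uniface.FAISS(embedding_size=512)
else:
    print('Vector indexing disabled: pip install faiss-cpu')
```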
@@ -2,9 +2,25 @@
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum


@dataclass(frozen=True, slots=True)
class ModelInfo:
    """Model metadata including download URL and SHA-256 hash.

    Attributes:
        url: Direct download link to the model weights.
        sha256: SHA-256 checksum for integrity verification.
    """

    url: str
    sha256: str


# fmt: off
class SphereFaceWeights(str, Enum):
    """
@@ -166,125 +182,202 @@ class MiniFASNetWeights(str, Enum):
    https://github.com/yakhyo/face-anti-spoofing

    Model Variants:
        - V1SE: Uses scale=4.0 for face crop (squeese-and-excitation version)
        - V1SE: Uses scale=4.0 for face crop (squeeze-and-excitation version)
        - V2: Uses scale=2.7 for face crop (improved version)
    """
    V1SE = "minifasnet_v1se"
    V2 = "minifasnet_v2"


MODEL_URLS: dict[Enum, str] = {
# Centralized Model Registry
MODEL_REGISTRY: dict[Enum, ModelInfo] = {
    # RetinaFace
    RetinaFaceWeights.MNET_025: 'https://github.com/yakhyo/uniface/releases/download/weights/retinaface_mv1_0.25.onnx',
    RetinaFaceWeights.MNET_050: 'https://github.com/yakhyo/uniface/releases/download/weights/retinaface_mv1_0.50.onnx',
    RetinaFaceWeights.MNET_V1: 'https://github.com/yakhyo/uniface/releases/download/weights/retinaface_mv1.onnx',
    RetinaFaceWeights.MNET_V2: 'https://github.com/yakhyo/uniface/releases/download/weights/retinaface_mv2.onnx',
    RetinaFaceWeights.RESNET18: 'https://github.com/yakhyo/uniface/releases/download/weights/retinaface_r18.onnx',
    RetinaFaceWeights.RESNET34: 'https://github.com/yakhyo/uniface/releases/download/weights/retinaface_r34.onnx',
    RetinaFaceWeights.MNET_025: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/retinaface_mv1_0.25.onnx',
        sha256='b7a7acab55e104dce6f32cdfff929bd83946da5cd869b9e2e9bdffafd1b7e4a5'
    ),
    RetinaFaceWeights.MNET_050: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/retinaface_mv1_0.50.onnx',
        sha256='d8977186f6037999af5b4113d42ba77a84a6ab0c996b17c713cc3d53b88bfc37'
    ),
    RetinaFaceWeights.MNET_V1: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/retinaface_mv1.onnx',
        sha256='75c961aaf0aff03d13c074e9ec656e5510e174454dd4964a161aab4fe5f04153'
    ),
    RetinaFaceWeights.MNET_V2: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/retinaface_mv2.onnx',
        sha256='3ca44c045651cabeed1193a1fae8946ad1f3a55da8fa74b341feab5a8319f757'
    ),
    RetinaFaceWeights.RESNET18: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/retinaface_r18.onnx',
        sha256='e8b5ddd7d2c3c8f7c942f9f10cec09d8e319f78f09725d3f709631de34fb649d'
    ),
    RetinaFaceWeights.RESNET34: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/retinaface_r34.onnx',
        sha256='bd0263dc2a465d32859555cb1741f2d98991eb0053696e8ee33fec583d30e630'
    ),

    # MobileFace
    MobileFaceWeights.MNET_025: 'https://github.com/yakhyo/uniface/releases/download/weights/mobilenetv1_0.25.onnx',
    MobileFaceWeights.MNET_V2: 'https://github.com/yakhyo/uniface/releases/download/weights/mobilenetv2.onnx',
    MobileFaceWeights.MNET_V3_SMALL: 'https://github.com/yakhyo/uniface/releases/download/weights/mobilenetv3_small.onnx',
    MobileFaceWeights.MNET_V3_LARGE: 'https://github.com/yakhyo/uniface/releases/download/weights/mobilenetv3_large.onnx',
    MobileFaceWeights.MNET_025: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/mobilenetv1_0.25.onnx',
        sha256='eeda7d23d9c2b40cf77fa8da8e895b5697465192648852216074679657f8ee8b'
    ),
    MobileFaceWeights.MNET_V2: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/mobilenetv2.onnx',
        sha256='38b148284dd48cc898d5d4453104252fbdcbacc105fe3f0b80e78954d9d20d89'
    ),
    MobileFaceWeights.MNET_V3_SMALL: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/mobilenetv3_small.onnx',
        sha256='d4acafa1039a82957aa8a9a1dac278a401c353a749c39df43de0e29cc1c127c3'
    ),
    MobileFaceWeights.MNET_V3_LARGE: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/mobilenetv3_large.onnx',
        sha256='0e48f8e11f070211716d03e5c65a3db35a5e917cfb5bc30552358629775a142a'
    ),

    # SphereFace
    SphereFaceWeights.SPHERE20: 'https://github.com/yakhyo/uniface/releases/download/weights/sphere20.onnx',
    SphereFaceWeights.SPHERE36: 'https://github.com/yakhyo/uniface/releases/download/weights/sphere36.onnx',
    SphereFaceWeights.SPHERE20: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/sphere20.onnx',
        sha256='c02878cf658eb1861f580b7e7144b0d27cc29c440bcaa6a99d466d2854f14c9d'
    ),
    SphereFaceWeights.SPHERE36: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/sphere36.onnx',
        sha256='13b3890cd5d7dec2b63f7c36fd7ce07403e5a0bbb701d9647c0289e6cbe7bb20'
    ),

    # ArcFace
    ArcFaceWeights.MNET: 'https://github.com/yakhyo/uniface/releases/download/weights/w600k_mbf.onnx',
    ArcFaceWeights.RESNET: 'https://github.com/yakhyo/uniface/releases/download/weights/w600k_r50.onnx',
    ArcFaceWeights.MNET: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/w600k_mbf.onnx',
        sha256='9cc6e4a75f0e2bf0b1aed94578f144d15175f357bdc05e815e5c4a02b319eb4f'
    ),
    ArcFaceWeights.RESNET: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/w600k_r50.onnx',
        sha256='4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43'
    ),

    # AdaFace
    AdaFaceWeights.IR_18: 'https://github.com/yakhyo/adaface-onnx/releases/download/weights/adaface_ir_18.onnx',
    AdaFaceWeights.IR_101: 'https://github.com/yakhyo/adaface-onnx/releases/download/weights/adaface_ir_101.onnx',
    AdaFaceWeights.IR_18: ModelInfo(
        url='https://github.com/yakhyo/adaface-onnx/releases/download/weights/adaface_ir_18.onnx',
        sha256='6b6a35772fb636cdd4fa86520c1a259d0c41472a76f70f802b351837a00d9870'
    ),
    AdaFaceWeights.IR_101: ModelInfo(
        url='https://github.com/yakhyo/adaface-onnx/releases/download/weights/adaface_ir_101.onnx',
        sha256='f2eb07d03de0af560a82e1214df799fec5e09375d43521e2868f9dc387e5a43e'
    ),

    # SCRFD
    SCRFDWeights.SCRFD_10G_KPS: 'https://github.com/yakhyo/uniface/releases/download/weights/scrfd_10g_kps.onnx',
    SCRFDWeights.SCRFD_500M_KPS: 'https://github.com/yakhyo/uniface/releases/download/weights/scrfd_500m_kps.onnx',
    SCRFDWeights.SCRFD_10G_KPS: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/scrfd_10g_kps.onnx',
        sha256='5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91'
    ),
    SCRFDWeights.SCRFD_500M_KPS: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/scrfd_500m_kps.onnx',
        sha256='5e4447f50245bbd7966bd6c0fa52938c61474a04ec7def48753668a9d8b4ea3a'
    ),

    # YOLOv5-Face
    YOLOv5FaceWeights.YOLOV5N: 'https://github.com/yakhyo/yolov5-face-onnx-inference/releases/download/weights/yolov5n_face.onnx',
    YOLOv5FaceWeights.YOLOV5S: 'https://github.com/yakhyo/yolov5-face-onnx-inference/releases/download/weights/yolov5s_face.onnx',
    YOLOv5FaceWeights.YOLOV5M: 'https://github.com/yakhyo/yolov5-face-onnx-inference/releases/download/weights/yolov5m_face.onnx',
    YOLOv5FaceWeights.YOLOV5N: ModelInfo(
        url='https://github.com/yakhyo/yolov5-face-onnx-inference/releases/download/weights/yolov5n_face.onnx',
        sha256='eb244a06e36999db732b317c2b30fa113cd6cfc1a397eaf738f2d6f33c01f640'
    ),
    YOLOv5FaceWeights.YOLOV5S: ModelInfo(
        url='https://github.com/yakhyo/yolov5-face-onnx-inference/releases/download/weights/yolov5s_face.onnx',
        sha256='fc682801cd5880e1e296184a14aea0035486b5146ec1a1389d2e7149cb134bb2'
    ),
    YOLOv5FaceWeights.YOLOV5M: ModelInfo(
        url='https://github.com/yakhyo/yolov5-face-onnx-inference/releases/download/weights/yolov5m_face.onnx',
        sha256='04302ce27a15bde3e20945691b688e2dd018a10e92dd8932146bede6a49207b2'
    ),

    # YOLOv8-Face
    YOLOv8FaceWeights.YOLOV8_LITE_S: 'https://github.com/yakhyo/yolov8-face-onnx-inference/releases/download/weights/yolov8-lite-s.onnx',
    YOLOv8FaceWeights.YOLOV8N: 'https://github.com/yakhyo/yolov8-face-onnx-inference/releases/download/weights/yolov8n-face.onnx',
    YOLOv8FaceWeights.YOLOV8_LITE_S: ModelInfo(
        url='https://github.com/yakhyo/yolov8-face-onnx-inference/releases/download/weights/yolov8-lite-s.onnx',
        sha256='11bc496be01356d2d960085bfd8abb8f103199900a034f239a8a1705a1b31dba'
    ),
    YOLOv8FaceWeights.YOLOV8N: ModelInfo(
        url='https://github.com/yakhyo/yolov8-face-onnx-inference/releases/download/weights/yolov8n-face.onnx',
        sha256='33f3951af7fc0c4d9b321b29cdcd8c9a59d0a29a8d4bdc01fcb5507d5c714809'
    ),

    # DDAMFN
    DDAMFNWeights.AFFECNET7: 'https://github.com/yakhyo/uniface/releases/download/weights/affecnet7.script',
    DDAMFNWeights.AFFECNET8: 'https://github.com/yakhyo/uniface/releases/download/weights/affecnet8.script',
    DDAMFNWeights.AFFECNET7: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/affecnet7.script',
        sha256='10535bf8b6afe8e9d6ae26cea6c3add9a93036e9addb6adebfd4a972171d015d'
    ),
    DDAMFNWeights.AFFECNET8: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/affecnet8.script',
        sha256='8c66963bc71db42796a14dfcbfcd181b268b65a3fc16e87147d6a3a3d7e0f487'
    ),

    # AgeGender
    AgeGenderWeights.DEFAULT: 'https://github.com/yakhyo/uniface/releases/download/weights/genderage.onnx',
    AgeGenderWeights.DEFAULT: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/genderage.onnx',
        sha256='4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb'
    ),

    # FairFace
    FairFaceWeights.DEFAULT: 'https://github.com/yakhyo/fairface-onnx/releases/download/weights/fairface.onnx',
    FairFaceWeights.DEFAULT: ModelInfo(
        url='https://github.com/yakhyo/fairface-onnx/releases/download/weights/fairface.onnx',
        sha256='9c8c47d437cd310538d233f2465f9ed0524cb7fb51882a37f74e8bc22437fdbf'
    ),

    # Landmarks
    LandmarkWeights.DEFAULT: 'https://github.com/yakhyo/uniface/releases/download/weights/2d106det.onnx',
    LandmarkWeights.DEFAULT: ModelInfo(
        url='https://github.com/yakhyo/uniface/releases/download/weights/2d106det.onnx',
        sha256='f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf'
    ),

    # Gaze (MobileGaze)
    GazeWeights.RESNET18: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/resnet18_gaze.onnx',
    GazeWeights.RESNET34: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/resnet34_gaze.onnx',
    GazeWeights.RESNET50: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/resnet50_gaze.onnx',
    GazeWeights.MOBILENET_V2: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/mobilenetv2_gaze.onnx',
    GazeWeights.MOBILEONE_S0: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/mobileone_s0_gaze.onnx',
    GazeWeights.RESNET18: ModelInfo(
        url='https://github.com/yakhyo/gaze-estimation/releases/download/weights/resnet18_gaze.onnx',
        sha256='404fec1efd07ff49f981e47f461c20c2627119e465ec441bbd1c067d3f16e657'
    ),
    GazeWeights.RESNET34: ModelInfo(
        url='https://github.com/yakhyo/gaze-estimation/releases/download/weights/resnet34_gaze.onnx',
        sha256='c8e6b14f6095d2425241b9302aa663d9a23b7dfb9d43941352b718c91dc7f2cf'
    ),
    GazeWeights.RESNET50: ModelInfo(
        url='https://github.com/yakhyo/gaze-estimation/releases/download/weights/resnet50_gaze.onnx',
        sha256='bb28d421565adc4dfb665742f8fc80bdef36dd8caa0c87e040e0937f9fdca9a6'
    ),
    GazeWeights.MOBILENET_V2: ModelInfo(
        url='https://github.com/yakhyo/gaze-estimation/releases/download/weights/mobilenetv2_gaze.onnx',
        sha256='b81312df85c7ac1c1b5f78c573620d22c2719cb839650e15f12dc7eecb7744a4'
    ),
    GazeWeights.MOBILEONE_S0: ModelInfo(
        url='https://github.com/yakhyo/gaze-estimation/releases/download/weights/mobileone_s0_gaze.onnx',
        sha256='8b4fdc4e3da44733c9a82e7776b411e4a39f94e8e285aee0fc85a548a55f7d9f'
    ),

    # Parsing
    ParsingWeights.RESNET18: 'https://github.com/yakhyo/face-parsing/releases/download/weights/resnet18.onnx',
    ParsingWeights.RESNET34: 'https://github.com/yakhyo/face-parsing/releases/download/weights/resnet34.onnx',
    ParsingWeights.RESNET18: ModelInfo(
        url='https://github.com/yakhyo/face-parsing/releases/download/weights/resnet18.onnx',
        sha256='0d9bd318e46987c3bdbfacae9e2c0f461cae1c6ac6ea6d43bbe541a91727e33f'
    ),
    ParsingWeights.RESNET34: ModelInfo(
        url='https://github.com/yakhyo/face-parsing/releases/download/weights/resnet34.onnx',
        sha256='5b805bba7b5660ab7070b5a381dcf75e5b3e04199f1e9387232a77a00095102e'
    ),

    # Anti-Spoofing (MiniFASNet)
    MiniFASNetWeights.V1SE: 'https://github.com/yakhyo/face-anti-spoofing/releases/download/weights/MiniFASNetV1SE.onnx',
    MiniFASNetWeights.V2: 'https://github.com/yakhyo/face-anti-spoofing/releases/download/weights/MiniFASNetV2.onnx',
    MiniFASNetWeights.V1SE: ModelInfo(
        url='https://github.com/yakhyo/face-anti-spoofing/releases/download/weights/MiniFASNetV1SE.onnx',
        sha256='ebab7f90c7833fbccd46d3a555410e78d969db5438e169b6524be444862b3676'
    ),
    MiniFASNetWeights.V2: ModelInfo(
        url='https://github.com/yakhyo/face-anti-spoofing/releases/download/weights/MiniFASNetV2.onnx',
        sha256='b32929adc2d9c34b9486f8c4c7bc97c1b69bc0ea9befefc380e4faae4e463907'
    ),

    # XSeg
    XSegWeights.DEFAULT: 'https://github.com/yakhyo/face-segmentation/releases/download/weights/xseg.onnx',
    XSegWeights.DEFAULT: ModelInfo(
        url='https://github.com/yakhyo/face-segmentation/releases/download/weights/xseg.onnx',
        sha256='0b57328efcb839d85973164b617ceee9dfe6cfcb2c82e8a033bba9f4f09b27e5'
    ),
}

MODEL_SHA256: dict[Enum, str] = {
    # RetinaFace
    RetinaFaceWeights.MNET_025: 'b7a7acab55e104dce6f32cdfff929bd83946da5cd869b9e2e9bdffafd1b7e4a5',
    RetinaFaceWeights.MNET_050: 'd8977186f6037999af5b4113d42ba77a84a6ab0c996b17c713cc3d53b88bfc37',
    RetinaFaceWeights.MNET_V1: '75c961aaf0aff03d13c074e9ec656e5510e174454dd4964a161aab4fe5f04153',
    RetinaFaceWeights.MNET_V2: '3ca44c045651cabeed1193a1fae8946ad1f3a55da8fa74b341feab5a8319f757',
    RetinaFaceWeights.RESNET18: 'e8b5ddd7d2c3c8f7c942f9f10cec09d8e319f78f09725d3f709631de34fb649d',
    RetinaFaceWeights.RESNET34: 'bd0263dc2a465d32859555cb1741f2d98991eb0053696e8ee33fec583d30e630',
    # MobileFace
    MobileFaceWeights.MNET_025: 'eeda7d23d9c2b40cf77fa8da8e895b5697465192648852216074679657f8ee8b',
    MobileFaceWeights.MNET_V2: '38b148284dd48cc898d5d4453104252fbdcbacc105fe3f0b80e78954d9d20d89',
    MobileFaceWeights.MNET_V3_SMALL: 'd4acafa1039a82957aa8a9a1dac278a401c353a749c39df43de0e29cc1c127c3',
    MobileFaceWeights.MNET_V3_LARGE: '0e48f8e11f070211716d03e5c65a3db35a5e917cfb5bc30552358629775a142a',
    # SphereFace
    SphereFaceWeights.SPHERE20: 'c02878cf658eb1861f580b7e7144b0d27cc29c440bcaa6a99d466d2854f14c9d',
    SphereFaceWeights.SPHERE36: '13b3890cd5d7dec2b63f7c36fd7ce07403e5a0bbb701d9647c0289e6cbe7bb20',
    # ArcFace
    ArcFaceWeights.MNET: '9cc6e4a75f0e2bf0b1aed94578f144d15175f357bdc05e815e5c4a02b319eb4f',
    ArcFaceWeights.RESNET: '4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43',
    # AdaFace
    AdaFaceWeights.IR_18: '6b6a35772fb636cdd4fa86520c1a259d0c41472a76f70f802b351837a00d9870',
    AdaFaceWeights.IR_101: 'f2eb07d03de0af560a82e1214df799fec5e09375d43521e2868f9dc387e5a43e',
    # SCRFD
    SCRFDWeights.SCRFD_10G_KPS: '5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91',
    SCRFDWeights.SCRFD_500M_KPS: '5e4447f50245bbd7966bd6c0fa52938c61474a04ec7def48753668a9d8b4ea3a',
    # YOLOv5-Face
    YOLOv5FaceWeights.YOLOV5N: 'eb244a06e36999db732b317c2b30fa113cd6cfc1a397eaf738f2d6f33c01f640',
    YOLOv5FaceWeights.YOLOV5S: 'fc682801cd5880e1e296184a14aea0035486b5146ec1a1389d2e7149cb134bb2',
    YOLOv5FaceWeights.YOLOV5M: '04302ce27a15bde3e20945691b688e2dd018a10e92dd8932146bede6a49207b2',
    # YOLOv8-Face
    YOLOv8FaceWeights.YOLOV8_LITE_S: '11bc496be01356d2d960085bfd8abb8f103199900a034f239a8a1705a1b31dba',
    YOLOv8FaceWeights.YOLOV8N: '33f3951af7fc0c4d9b321b29cdcd8c9a59d0a29a8d4bdc01fcb5507d5c714809',
    # DDAMFN
    DDAMFNWeights.AFFECNET7: '10535bf8b6afe8e9d6ae26cea6c3add9a93036e9addb6adebfd4a972171d015d',
    DDAMFNWeights.AFFECNET8: '8c66963bc71db42796a14dfcbfcd181b268b65a3fc16e87147d6a3a3d7e0f487',
    # AgeGender
    AgeGenderWeights.DEFAULT: '4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb',
    # FairFace
    FairFaceWeights.DEFAULT: '9c8c47d437cd310538d233f2465f9ed0524cb7fb51882a37f74e8bc22437fdbf',
    # Landmark
    LandmarkWeights.DEFAULT: 'f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf',
    # MobileGaze (trained on Gaze360)
    GazeWeights.RESNET18: '404fec1efd07ff49f981e47f461c20c2627119e465ec441bbd1c067d3f16e657',
    GazeWeights.RESNET34: 'c8e6b14f6095d2425241b9302aa663d9a23b7dfb9d43941352b718c91dc7f2cf',
    GazeWeights.RESNET50: 'bb28d421565adc4dfb665742f8fc80bdef36dd8caa0c87e040e0937f9fdca9a6',
    GazeWeights.MOBILENET_V2: 'b81312df85c7ac1c1b5f78c573620d22c2719cb839650e15f12dc7eecb7744a4',
    GazeWeights.MOBILEONE_S0: '8b4fdc4e3da44733c9a82e7776b411e4a39f94e8e285aee0fc85a548a55f7d9f',
    # Face Parsing
    ParsingWeights.RESNET18: '0d9bd318e46987c3bdbfacae9e2c0f461cae1c6ac6ea6d43bbe541a91727e33f',
    ParsingWeights.RESNET34: '5b805bba7b5660ab7070b5a381dcf75e5b3e04199f1e9387232a77a00095102e',
    # Anti-Spoofing (MiniFASNet)
    MiniFASNetWeights.V1SE: 'ebab7f90c7833fbccd46d3a555410e78d969db5438e169b6524be444862b3676',
    MiniFASNetWeights.V2: 'b32929adc2d9c34b9486f8c4c7bc97c1b69bc0ea9befefc380e4faae4e463907',
    # XSeg
    XSegWeights.DEFAULT: '0b57328efcb839d85973164b617ceee9dfe6cfcb2c82e8a033bba9f4f09b27e5',
}

# Backward compatibility (optional, can be removed if all code uses MODEL_REGISTRY)
MODEL_URLS: dict[Enum, str] = {k: v.url for k, v in MODEL_REGISTRY.items()}
MODEL_SHA256: dict[Enum, str] = {k: v.sha256 for k, v in MODEL_REGISTRY.items()}

CHUNK_SIZE = 8192
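A quick sketch of reading the new registry; the `uniface.constants` module path is an assumption inferred from the import style elsewhere in the package:

```python
# Sketch: look up one entry from the centralized registry.
from uniface.constants import MODEL_REGISTRY, MODEL_URLS, RetinaFaceWeights  # module path assumed

info = MODEL_REGISTRY[RetinaFaceWeights.MNET_V2]
print(info.url)     # release download URL
print(info.sha256)  # checksum verified after download

# The backward-compatibility dicts stay in sync automatically:
assert MODEL_URLS[RetinaFaceWeights.MNET_V2] == info.url
```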
9
uniface/indexing/__init__.py
Normal file
@@ -0,0 +1,9 @@
# Copyright 2025-2026 Yakhyokhuja Valikhujaev
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo

"""Vector indexing backends for fast similarity search."""

from uniface.indexing.faiss import FAISS

__all__ = ['FAISS']
197
uniface/indexing/faiss.py
Normal file
@@ -0,0 +1,197 @@
# Copyright 2025-2026 Yakhyokhuja Valikhujaev
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo

from __future__ import annotations

import json
import os
from typing import Any

import numpy as np

from uniface.log import Logger

__all__ = ['FAISS']

Metadata = dict[str, Any]


def _import_faiss():
    """Lazily import faiss, raising a clear error if not installed."""
    # Prevent OpenMP abort on macOS when multiple libraries (e.g. scipy,
    # torch) each bundle their own libomp.
    os.environ.setdefault('KMP_DUPLICATE_LIB_OK', 'TRUE')

    try:
        import faiss
    except ImportError as exc:
        raise ImportError(
            'faiss is required for FAISS vector store. '
            'Install it with: pip install faiss-cpu (CPU) '
            'or: pip install faiss-gpu (CUDA)'
        ) from exc
    return faiss


class FAISS:
    """FAISS vector store using IndexFlatIP (inner product).

    Vectors must be L2-normalised **before** being added so that inner
    product equals cosine similarity. The store does not normalise
    internally -- that is the caller's responsibility.

    Each vector is paired with a metadata dict that can carry any
    JSON-serialisable payload (person ID, name, source image, etc.).

    Args:
        embedding_size: Dimension of embedding vectors.
        db_path: Directory for persisting the index and metadata.

    Example:
        >>> from uniface.indexing import FAISS
        >>> store = FAISS(embedding_size=512, db_path='./my_index')
        >>> store.add(embedding, {'person_id': '001', 'name': 'Alice'})
        >>> result, score = store.search(query_embedding)
        >>> result['name']
        'Alice'
    """

    def __init__(
        self,
        embedding_size: int = 512,
        db_path: str = './vector_index',
    ) -> None:
        faiss = _import_faiss()

        self.embedding_size = embedding_size
        self.db_path = db_path
        self._index_file = os.path.join(db_path, 'faiss_index.bin')
        self._meta_file = os.path.join(db_path, 'metadata.json')

        os.makedirs(db_path, exist_ok=True)

        self.index = faiss.IndexFlatIP(embedding_size)
        self.metadata: list[Metadata] = []

    def add(self, embedding: np.ndarray, metadata: Metadata) -> None:
        """Add a single embedding with associated metadata.

        Args:
            embedding: Embedding vector (must be L2-normalised).
            metadata: Arbitrary dict of JSON-serialisable key-value pairs.
        """
        vec = self._prepare(embedding).reshape(1, -1)
        self.index.add(vec)
        self.metadata.append(metadata)

    def search(
        self,
        embedding: np.ndarray,
        threshold: float = 0.4,
    ) -> tuple[Metadata | None, float]:
        """Find the closest match for a query embedding.

        Args:
            embedding: Query embedding vector (must be L2-normalised).
            threshold: Minimum cosine similarity to accept a match.

        Returns:
            ``(metadata, similarity)`` for the best match, or
            ``(None, similarity)`` when below *threshold* or the
            index is empty.
        """
        if self.index.ntotal == 0:
            return None, 0.0

        vec = self._prepare(embedding).reshape(1, -1)
        similarities, indices = self.index.search(vec, 1)

        similarity = float(similarities[0][0])
        idx = int(indices[0][0])

        if similarity > threshold and 0 <= idx < len(self.metadata):
            return self.metadata[idx], similarity
        return None, similarity

    def remove(self, key: str, value: Any) -> int:
        """Remove all entries where ``metadata[key] == value`` and rebuild.

        Args:
            key: Metadata key to match against.
            value: Value to match.

        Returns:
            Number of entries removed.
        """
        faiss = _import_faiss()

        keep = [i for i, m in enumerate(self.metadata) if m.get(key) != value]
        removed = len(self.metadata) - len(keep)
        if removed == 0:
            return 0

        if keep:
            vectors = np.empty((len(keep), self.embedding_size), dtype=np.float32)
            for dst, src in enumerate(keep):
                self.index.reconstruct(src, vectors[dst])
            new_index = faiss.IndexFlatIP(self.embedding_size)
            new_index.add(vectors)
        else:
            new_index = faiss.IndexFlatIP(self.embedding_size)

        self.index = new_index
        self.metadata = [self.metadata[i] for i in keep]
        Logger.info('Removed %d entries where %s=%s (%d remaining)', removed, key, value, self.index.ntotal)
        return removed

    def save(self) -> None:
        """Persist the FAISS index and metadata to disk."""
        faiss = _import_faiss()

        faiss.write_index(self.index, self._index_file)
        with open(self._meta_file, 'w', encoding='utf-8') as fh:
            json.dump(self.metadata, fh, ensure_ascii=False, indent=2)
        Logger.info('Saved FAISS index with %d vectors to %s', self.index.ntotal, self.db_path)

    def load(self) -> bool:
        """Load a previously saved index and metadata from disk.

        Returns:
            ``True`` if loaded successfully, ``False`` if files are missing.

        Raises:
            RuntimeError: If files exist but cannot be read.
        """
        if not (os.path.exists(self._index_file) and os.path.exists(self._meta_file)):
            return False

        faiss = _import_faiss()

        try:
            loaded_index = faiss.read_index(self._index_file)
            with open(self._meta_file, encoding='utf-8') as fh:
                loaded_metadata: list[Metadata] = json.load(fh)
        except Exception as exc:
            raise RuntimeError(f'Failed to load FAISS index from {self.db_path}') from exc

        self.index = loaded_index
        self.metadata = loaded_metadata
        Logger.info('Loaded FAISS index with %d vectors from %s', self.index.ntotal, self.db_path)
        return True

    @property
    def size(self) -> int:
        """Number of vectors currently in the index."""
        return self.index.ntotal

    @staticmethod
    def _prepare(vec: np.ndarray) -> np.ndarray:
        """Cast to contiguous float32 for FAISS compatibility."""
        return np.ascontiguousarray(vec.ravel(), dtype=np.float32)

    def __len__(self) -> int:
        return self.index.ntotal

    def __repr__(self) -> str:
        return f'FAISS(embedding_size={self.embedding_size}, vectors={self.index.ntotal})'
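A minimal end-to-end sketch of the store, with a synthetic unit vector standing in for a real face embedding:

```python
import numpy as np
from uniface.indexing import FAISS

store = FAISS(embedding_size=512, db_path='./vector_index')

# The store expects unit-length vectors; normalise before adding.
rng = np.random.default_rng(0)
emb = rng.standard_normal(512).astype(np.float32)
emb /= np.linalg.norm(emb)

store.add(emb, {'person_id': '001', 'name': 'Alice'})
store.save()

match, score = store.search(emb, threshold=0.4)
print(match, round(score, 3))  # {'person_id': '001', 'name': 'Alice'} ~1.0
```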
@@ -14,6 +14,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from enum import Enum
import hashlib
import os
import time

import requests
from tqdm import tqdm
@@ -64,7 +65,12 @@ def set_cache_dir(path: str) -> None:
    Logger.info(f'Cache directory set to: {path}')


def verify_model_weights(model_name: Enum, root: str | None = None) -> str:
def verify_model_weights(
    model_name: Enum,
    root: str | None = None,
    timeout: int = 60,
    max_retries: int = 3,
) -> str:
    """Ensure model weights are present, downloading and verifying them if necessary.

    Given a model identifier from an Enum class (e.g., `RetinaFaceWeights.MNET_V2`),
@@ -76,6 +82,8 @@ def verify_model_weights(model_name: Enum, root: str | None = None) -> str:
        model_name: Model weight identifier enum (e.g., `RetinaFaceWeights.MNET_V2`).
        root: Directory to store or locate the model weights.
            If None, uses the cache directory from :func:`get_cache_dir`.
        timeout: Connection timeout in seconds. Defaults to 60.
        max_retries: Maximum number of download attempts. Defaults to 3.

    Returns:
        Absolute path to the verified model weights file.
@@ -95,59 +103,75 @@ def verify_model_weights(model_name: Enum, root: str | None = None) -> str:
    root = os.path.expanduser(root) if root is not None else get_cache_dir()
    os.makedirs(root, exist_ok=True)

    # Keep model_name as enum for dictionary lookup
    url = const.MODEL_URLS.get(model_name)
    if not url:
        Logger.error(f"No URL found for model '{model_name}'")
        raise ValueError(f"No URL found for model '{model_name}'")
    # Look up model info in the registry
    model_info = const.MODEL_REGISTRY.get(model_name)
    if not model_info:
        Logger.error(f"No entry found in MODEL_REGISTRY for model '{model_name}'")
        raise ValueError(f"Unknown model identifier: '{model_name}'")

    url = model_info.url
    expected_hash = model_info.sha256

    file_ext = os.path.splitext(url)[1]
    model_path = os.path.normpath(os.path.join(root, f'{model_name.value}{file_ext}'))

    if not os.path.exists(model_path):
        Logger.info(f"Downloading model '{model_name}' from {url}")
        Logger.info(f"Downloading model '{model_name.value}' from {url}")
        try:
            download_file(url, model_path)
            Logger.info(f"Successfully downloaded '{model_name}' to {model_path}")
            download_file(url, model_path, timeout=timeout, max_retries=max_retries)
            Logger.info(f"Successfully downloaded '{model_name.value}' to {model_path}")
        except Exception as e:
            Logger.error(f"Failed to download model '{model_name}': {e}")
            raise ConnectionError(f"Download failed for '{model_name}'") from e
            Logger.error(f"Failed to download model '{model_name.value}': {e}")
            raise ConnectionError(f"Download failed for '{model_name.value}' after {max_retries} attempts") from e

    expected_hash = const.MODEL_SHA256.get(model_name)
    if expected_hash and not verify_file_hash(model_path, expected_hash):
        os.remove(model_path)  # Remove corrupted file
        Logger.warning('Corrupted weight detected. Removing...')
        raise ValueError(f"Hash mismatch for '{model_name}'. The file may be corrupted; please try downloading again.")
        Logger.warning(f"Corrupted weights detected for '{model_name.value}'. Removing...")
        raise ValueError(f"Hash mismatch for '{model_name.value}'. The file may be corrupted; please try again.")

    return model_path


def download_file(url: str, dest_path: str, timeout: int = 30) -> None:
    """Download a file from a URL in chunks and save it to the destination path.
def download_file(url: str, dest_path: str, timeout: int = 60, max_retries: int = 3) -> None:
    """Download a file from a URL with retry logic.

    Args:
        url: URL to download from.
        dest_path: Local file path to save to.
        timeout: Connection timeout in seconds. Defaults to 30.
        timeout: Connection timeout in seconds. Defaults to 60.
        max_retries: Maximum number of attempts. Defaults to 3.
    """
    try:
        response = requests.get(url, stream=True, timeout=timeout)
        response.raise_for_status()
        with (
            open(dest_path, 'wb') as file,
            tqdm(
                desc=f'Downloading {dest_path}',
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
            ) as progress,
        ):
            for chunk in response.iter_content(chunk_size=const.CHUNK_SIZE):
                if chunk:
                    file.write(chunk)
                    progress.update(len(chunk))
    except requests.RequestException as e:
        raise ConnectionError(f'Failed to download file from {url}. Error: {e}') from e
    last_error = None
    for attempt in range(max_retries):
        try:
            response = requests.get(url, stream=True, timeout=timeout)
            response.raise_for_status()

            total_size = int(response.headers.get('content-length', 0))

            with (
                open(dest_path, 'wb') as file,
                tqdm(
                    total=total_size,
                    desc=f'Attempt {attempt + 1}/{max_retries}',
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as progress,
            ):
                for chunk in response.iter_content(chunk_size=const.CHUNK_SIZE):
                    if chunk:
                        file.write(chunk)
                        progress.update(len(chunk))
            return  # Success
        except (OSError, requests.RequestException) as e:
            last_error = e
            Logger.warning(f'Download attempt {attempt + 1} failed: {e}. Retrying...')
            if os.path.exists(dest_path):
                os.remove(dest_path)
            time.sleep(2**attempt)  # Exponential backoff: 1 s, 2 s, 4 s

    raise ConnectionError(f'Failed to download file from {url}. Error: {last_error}')


def verify_file_hash(file_path: str, expected_hash: str) -> bool:
@@ -162,7 +186,9 @@ def verify_file_hash(file_path: str, expected_hash: str) -> bool:
    return actual_hash == expected_hash


def download_models(model_names: list[Enum], max_workers: int = 4) -> dict[Enum, str]:
def download_models(
    model_names: list[Enum], max_workers: int = 4, timeout: int = 60, max_retries: int = 3
) -> dict[Enum, str]:
    """Download and verify multiple models concurrently.

    Uses a thread pool to download models in parallel, which is significantly
@@ -171,6 +197,8 @@ def download_models(model_names: list[Enum], max_workers: int = 4) -> dict[Enum,
    Args:
        model_names: List of model weight enum identifiers to download.
        max_workers: Maximum number of concurrent download threads. Defaults to 4.
        timeout: Connection timeout in seconds. Defaults to 60.
        max_retries: Maximum number of attempts per model. Defaults to 3.

    Returns:
        Mapping of each model enum to its local file path.
@@ -187,7 +215,10 @@ def download_models(model_names: list[Enum], max_workers: int = 4) -> dict[Enum,
    errors: list[str] = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_model = {executor.submit(verify_model_weights, name): name for name in model_names}
        future_to_model = {
            executor.submit(verify_model_weights, name, timeout=timeout, max_retries=max_retries): name
            for name in model_names
        }

        for future in as_completed(future_to_model):
            model = future_to_model[future]
@@ -204,8 +235,3 @@ def download_models(model_names: list[Enum], max_workers: int = 4) -> dict[Enum,

    Logger.info(f'All {len(results)} model(s) downloaded and verified')
    return results


if __name__ == '__main__':
    for model in const.RetinaFaceWeights:
        model_path = verify_model_weights(model)
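The new keyword arguments thread through to every worker, so a batch prefetch can be tuned in one place. A sketch, assuming the module is importable as `uniface.model_store`:

```python
# Sketch: prefetch two models with the new retry/timeout knobs.
from uniface.constants import ArcFaceWeights, RetinaFaceWeights  # module paths assumed
from uniface.model_store import download_models

paths = download_models(
    [RetinaFaceWeights.MNET_V2, ArcFaceWeights.MNET],
    max_workers=2,
    timeout=60,
    max_retries=3,
)
print(paths[RetinaFaceWeights.MNET_V2])  # local .onnx path
```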
@@ -10,6 +10,8 @@ inference sessions with automatic hardware acceleration detection.

from __future__ import annotations

import functools

import onnxruntime as ort

from uniface.log import Logger
@@ -17,6 +19,7 @@ from uniface.log import Logger
__all__ = ['create_onnx_session', 'get_available_providers']


@functools.lru_cache(maxsize=1)
def get_available_providers() -> list[str]:
    """Get list of available ONNX Runtime execution providers.

@@ -30,7 +33,7 @@ def get_available_providers() -> list[str]:

    Example:
        >>> providers = get_available_providers()
        >>> # On M4 Mac: ['CoreMLExecutionProvider', 'CPUExecutionProvider']
        >>> # On macOS: ['CoreMLExecutionProvider', 'CPUExecutionProvider']
        >>> # On Linux with CUDA: ['CUDAExecutionProvider', 'CPUExecutionProvider']
    """
    available = ort.get_available_providers()
@@ -98,7 +101,7 @@ def create_onnx_session(
            'CPUExecutionProvider': 'CPU',
        }
        provider_display = provider_names.get(active_provider, active_provider)
        Logger.info(f'✓ Model loaded ({provider_display})')
        Logger.debug(f'Model loaded from {model_path} ({provider_display})')

        return session
    except Exception as e:

@@ -4,14 +4,11 @@

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar
from typing import ClassVar

import cv2
import numpy as np

if TYPE_CHECKING:
    pass

__all__ = ['BlurFace', 'EllipticalBlur']
