Mirror of https://github.com/yakhyo/uniface.git (synced 2025-12-30 09:02:25 +00:00)

feat: Add 2D Gaze estimation models (#34)

* feat: Add Gaze Estimation, update docs and Add example notebook, inference code
* docs: Update README.md

Committed by GitHub
Parent: da8a5cf35b
Commit: 4d1921e531
@@ -58,3 +58,6 @@ Open an issue or start a discussion on GitHub.

MODELS.md (42 lines changed)
@@ -291,6 +291,47 @@ emotion, confidence = predictor.predict(image, landmarks)

---

## Gaze Estimation Models

### MobileGaze Family

Real-time gaze direction prediction models trained on the Gaze360 dataset. Each model returns pitch (vertical) and yaw (horizontal) angles in radians.

| Model Name      | Params | Size    | MAE*  | Use Case                |
| --------------- | ------ | ------- | ----- | ----------------------- |
| `RESNET18`      | 11.7M  | 43 MB   | 12.84 | Balanced accuracy/speed |
| `RESNET34` ⭐    | 24.8M  | 81.6 MB | 11.33 | **Recommended default** |
| `RESNET50`      | 25.6M  | 91.3 MB | 11.34 | High accuracy           |
| `MOBILENET_V2`  | 3.5M   | 9.59 MB | 13.07 | Mobile/Edge devices     |
| `MOBILEONE_S0`  | 2.1M   | 4.8 MB  | 12.58 | Lightweight/Real-time   |

*MAE (Mean Absolute Error) in degrees on the Gaze360 test set; lower is better.

**Dataset**: Trained on Gaze360 (indoor/outdoor scenes with diverse head poses)

**Training**: 200 epochs with a classification-based approach (binned angles)
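
For intuition, here is a minimal sketch of how a binned-angle classifier decodes to a continuous angle, assuming the Gaze360 configuration used elsewhere in this commit (90 bins of 4°, 180° offset); the actual decoding lives in `uniface/gaze/models.py`, and the function name below is illustrative only.

```python
import numpy as np

def decode_binned_angle(logits: np.ndarray, n_bins: int = 90,
                        bin_width: float = 4.0, offset: float = 180.0) -> float:
    """Soft-argmax over bin probabilities, returning an angle in radians."""
    probs = np.exp(logits - logits.max())   # numerically stable softmax
    probs /= probs.sum()
    angle_deg = float((probs * np.arange(n_bins)).sum() * bin_width - offset)
    return float(np.radians(angle_deg))
```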

#### Usage

```python
from uniface import MobileGaze
from uniface.constants import GazeWeights
import numpy as np

# Default (recommended)
gaze_estimator = MobileGaze()  # Uses RESNET34

# Lightweight model
gaze_estimator = MobileGaze(model_name=GazeWeights.MOBILEONE_S0)

# Estimate gaze from face crop
pitch, yaw = gaze_estimator.estimate(face_crop)
print(f"Pitch: {np.degrees(pitch):.1f}°, Yaw: {np.degrees(yaw):.1f}°")
```

**Note**: Requires face crop as input. Use face detection first to obtain bounding boxes.

---

## Model Updates

Models are automatically downloaded and cached on first use. Cache location: `~/.uniface/models/`
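
The download-and-verify logic itself lives in `uniface/model_store.py` (not part of this diff). Purely as an illustration of the flow implied by the `MODEL_URLS` and `MODEL_SHA256` tables added in `uniface/constants.py`, a minimal sketch might look like this; the helper name and structure here are hypothetical.

```python
import hashlib
import urllib.request
from pathlib import Path

CACHE_DIR = Path.home() / '.uniface' / 'models'

def download_and_verify(name: str, url: str, sha256: str) -> Path:
    """Download a model once, cache it under ~/.uniface/models/, and check its hash."""
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    path = CACHE_DIR / f'{name}.onnx'
    if not path.exists():
        urllib.request.urlretrieve(url, str(path))
    digest = hashlib.sha256(path.read_bytes()).hexdigest()
    if digest != sha256:
        raise RuntimeError(f'Checksum mismatch for {name}: expected {sha256}, got {digest}')
    return path
```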

@@ -330,6 +371,7 @@ python scripts/download_model.py --model MNET_V2

- **YOLOv5-Face Original**: [deepcam-cn/yolov5-face](https://github.com/deepcam-cn/yolov5-face) - Original PyTorch implementation
- **YOLOv5-Face ONNX**: [yakhyo/yolov5-face-onnx-inference](https://github.com/yakhyo/yolov5-face-onnx-inference) - ONNX inference implementation
- **Face Recognition Training**: [yakhyo/face-recognition](https://github.com/yakhyo/face-recognition) - ArcFace, MobileFace, SphereFace training code
- **Gaze Estimation Training**: [yakhyo/gaze-estimation](https://github.com/yakhyo/gaze-estimation) - MobileGaze training code and pretrained weights
- **InsightFace**: [deepinsight/insightface](https://github.com/deepinsight/insightface) - Model architectures and pretrained weights

### Papers

@@ -242,7 +242,50 @@ if faces:

---

-## 7. Batch Processing (3 minutes)
+## 7. Gaze Estimation (2 minutes)

Estimate where a person is looking:

```python
import cv2
import numpy as np
from uniface import RetinaFace, MobileGaze
from uniface.visualization import draw_gaze

# Initialize models
detector = RetinaFace()
gaze_estimator = MobileGaze()

# Load image
image = cv2.imread("photo.jpg")
faces = detector.detect(image)

# Estimate gaze for each face
for i, face in enumerate(faces):
    bbox = face['bbox']
    x1, y1, x2, y2 = map(int, bbox[:4])
    face_crop = image[y1:y2, x1:x2]

    if face_crop.size > 0:
        pitch, yaw = gaze_estimator.estimate(face_crop)
        print(f"Face {i+1}: pitch={np.degrees(pitch):.1f}°, yaw={np.degrees(yaw):.1f}°")

        # Draw gaze direction
        draw_gaze(image, bbox, pitch, yaw)

cv2.imwrite("gaze_output.jpg", image)
```

**Output:**

```
Face 1: pitch=5.2°, yaw=-12.3°
Face 2: pitch=-8.1°, yaw=15.7°
```

---

## 8. Batch Processing (3 minutes)

Process multiple images:

@@ -275,7 +318,7 @@ print("Done!")

---

-## 8. Model Selection
+## 9. Model Selection

Choose the right model for your use case:

@@ -326,6 +369,22 @@ recognizer = MobileFace(model_name=MobileFaceWeights.MNET_V2) # Fast, small siz

recognizer = SphereFace(model_name=SphereFaceWeights.SPHERE20)  # Alternative method
```

### Gaze Estimation Models

```python
from uniface import MobileGaze
from uniface.constants import GazeWeights

# Default (recommended)
gaze_estimator = MobileGaze()  # Uses RESNET34

# Lightweight (mobile/edge devices)
gaze_estimator = MobileGaze(model_name=GazeWeights.MOBILEONE_S0)

# High accuracy
gaze_estimator = MobileGaze(model_name=GazeWeights.RESNET50)
```

---

## Common Issues

@@ -400,4 +459,5 @@ Explore interactive examples for common tasks:

- **RetinaFace Training**: [yakhyo/retinaface-pytorch](https://github.com/yakhyo/retinaface-pytorch)
- **YOLOv5-Face ONNX**: [yakhyo/yolov5-face-onnx-inference](https://github.com/yakhyo/yolov5-face-onnx-inference)
- **Face Recognition Training**: [yakhyo/face-recognition](https://github.com/yakhyo/face-recognition)
- **Gaze Estimation Training**: [yakhyo/gaze-estimation](https://github.com/yakhyo/gaze-estimation)
- **InsightFace**: [deepinsight/insightface](https://github.com/deepinsight/insightface)

README.md (33 lines changed)

@@ -20,6 +20,7 @@

- **High-Speed Face Detection**: ONNX-optimized RetinaFace, SCRFD, and YOLOv5-Face models
- **Facial Landmark Detection**: Accurate 106-point landmark localization
- **Face Recognition**: ArcFace, MobileFace, and SphereFace embeddings
- **Gaze Estimation**: Real-time gaze direction prediction with MobileGaze
- **Attribute Analysis**: Age, gender, and emotion detection
- **Face Alignment**: Precise alignment for downstream tasks
- **Hardware Acceleration**: ARM64 optimizations (Apple Silicon), CUDA (NVIDIA), CPU fallback

@@ -152,6 +153,29 @@ gender_str = 'Female' if gender == 0 else 'Male'

print(f"{gender_str}, {age} years old")
```

### Gaze Estimation

```python
from uniface import RetinaFace, MobileGaze
from uniface.visualization import draw_gaze
import numpy as np

detector = RetinaFace()
gaze_estimator = MobileGaze()

faces = detector.detect(image)
for face in faces:
    bbox = face['bbox']
    x1, y1, x2, y2 = map(int, bbox[:4])
    face_crop = image[y1:y2, x1:x2]

    pitch, yaw = gaze_estimator.estimate(face_crop)
    print(f"Gaze: pitch={np.degrees(pitch):.1f}°, yaw={np.degrees(yaw):.1f}°")

    # Visualize
    draw_gaze(image, bbox, pitch, yaw)
```

---

## Documentation

@@ -252,6 +276,12 @@ faces = detect_faces(image, method='retinaface', conf_thresh=0.8) # methods: re

| `AgeGender` | `model_name=AgeGenderWeights.DEFAULT`; `input_size` auto-detected | Requires bbox; ONNXRuntime |
| `Emotion` | `model_weights=DDAMFNWeights.AFFECNET7`, `input_size=(112, 112)` | Requires 5-point landmarks; TorchScript |

**Gaze Estimation**

| Class        | Key params (defaults)             | Notes                                                       |
| ------------ | --------------------------------- | ----------------------------------------------------------- |
| `MobileGaze` | `model_name=GazeWeights.RESNET34` | Returns (pitch, yaw) angles in radians; trained on Gaze360  |

---

## Model Performance

@@ -298,6 +328,7 @@ Interactive examples covering common face analysis tasks:

| **Face Recognition** | Extract face embeddings and compare faces | [face_analyzer.ipynb](examples/face_analyzer.ipynb) |
| **Face Verification** | Compare two faces to verify identity | [face_verification.ipynb](examples/face_verification.ipynb) |
| **Face Search** | Find a person in a group photo | [face_search.ipynb](examples/face_search.ipynb) |
| **Gaze Estimation** | Estimate gaze direction from face images | [gaze_estimation.ipynb](examples/gaze_estimation.ipynb) |

### Webcam Face Detection

@@ -488,6 +519,7 @@ uniface/

│   ├── detection/        # Face detection models
│   ├── recognition/      # Face recognition models
│   ├── landmark/         # Landmark detection
│   ├── gaze/             # Gaze estimation
│   ├── attribute/        # Age, gender, emotion
│   ├── onnx_utils.py     # ONNX Runtime utilities
│   ├── model_store.py    # Model download & caching

@@ -504,6 +536,7 @@ uniface/

- **RetinaFace Training**: [yakhyo/retinaface-pytorch](https://github.com/yakhyo/retinaface-pytorch) - PyTorch implementation and training code
- **YOLOv5-Face ONNX**: [yakhyo/yolov5-face-onnx-inference](https://github.com/yakhyo/yolov5-face-onnx-inference) - ONNX inference implementation
- **Face Recognition Training**: [yakhyo/face-recognition](https://github.com/yakhyo/face-recognition) - ArcFace, MobileFace, SphereFace training code
- **Gaze Estimation Training**: [yakhyo/gaze-estimation](https://github.com/yakhyo/gaze-estimation) - MobileGaze training code and pretrained weights
- **InsightFace**: [deepinsight/insightface](https://github.com/deepinsight/insightface) - Model architectures and pretrained weights

## Contributing

examples/gaze_estimation.ipynb (271 lines, new file)

File diff suppressed because one or more lines are too long

@@ -1,7 +1,7 @@

[project]
name = "uniface"
-version = "1.3.2"
+version = "1.4.0"
-description = "UniFace: A Comprehensive Library for Face Detection, Recognition, Landmark Analysis, Age, and Gender Detection"
+description = "UniFace: A Comprehensive Library for Face Detection, Recognition, Landmark Analysis, Gaze Estimation, Age, and Gender Detection"
readme = "README.md"
license = { text = "MIT" }
authors = [{ name = "Yakhyokhuja Valikhujaev", email = "yakhyo9696@gmail.com" }]

@@ -14,6 +14,7 @@ keywords = [

    "face-detection",
    "face-recognition",
    "facial-landmarks",
    "gaze-estimation",
    "age-detection",
    "gender-detection",
    "computer-vision",

@@ -9,6 +9,7 @@ Scripts for testing UniFace features.

| `run_detection.py` | Face detection on image or webcam |
| `run_age_gender.py` | Age and gender prediction |
| `run_emotion.py` | Emotion detection (7 or 8 emotions) |
| `run_gaze_estimation.py` | Gaze direction estimation |
| `run_landmarks.py` | 106-point facial landmark detection |
| `run_recognition.py` | Face embedding extraction and comparison |
| `run_face_analyzer.py` | Complete face analysis (detection + recognition + attributes) |

@@ -33,6 +34,10 @@ python scripts/run_age_gender.py --webcam

python scripts/run_emotion.py --image assets/test.jpg
python scripts/run_emotion.py --webcam

# Gaze estimation
python scripts/run_gaze_estimation.py --image assets/test.jpg
python scripts/run_gaze_estimation.py --webcam

# Landmarks
python scripts/run_landmarks.py --image assets/test.jpg
python scripts/run_landmarks.py --webcam

@@ -79,7 +79,9 @@ def run_webcam(detector, age_gender, threshold: float = 0.6):

        bboxes = [f['bbox'] for f in faces]
        scores = [f['confidence'] for f in faces]
        landmarks = [f['landmarks'] for f in faces]
-       draw_detections(frame, bboxes, scores, landmarks, vis_threshold=threshold)
+       draw_detections(
+           image=frame, bboxes=bboxes, scores=scores, landmarks=landmarks, vis_threshold=threshold, fancy_bbox=True
+       )

        for face in faces:
            gender_id, age = age_gender.predict(frame, face['bbox'])  # predict per face

@@ -98,7 +98,7 @@ def main():

    else:
        from uniface.constants import YOLOv5FaceWeights

-       detector = YOLOv5Face(model_name=YOLOv5FaceWeights.YOLOV5N)
+       detector = YOLOv5Face(model_name=YOLOv5FaceWeights.YOLOV5M)

    if args.webcam:
        run_webcam(detector, args.threshold)

scripts/run_gaze_estimation.py (104 lines, new file)

@@ -0,0 +1,104 @@
# Gaze estimation on detected faces
# Usage: python run_gaze_estimation.py --image path/to/image.jpg
#        python run_gaze_estimation.py --webcam

import argparse
import os
from pathlib import Path

import cv2
import numpy as np

from uniface import RetinaFace
from uniface.gaze import MobileGaze
from uniface.visualization import draw_gaze


def process_image(detector, gaze_estimator, image_path: str, save_dir: str = 'outputs'):
    image = cv2.imread(image_path)
    if image is None:
        print(f"Error: Failed to load image from '{image_path}'")
        return

    faces = detector.detect(image)
    print(f'Detected {len(faces)} face(s)')

    for i, face in enumerate(faces):
        bbox = face['bbox']
        x1, y1, x2, y2 = map(int, bbox[:4])
        face_crop = image[y1:y2, x1:x2]

        if face_crop.size == 0:
            continue

        pitch, yaw = gaze_estimator.estimate(face_crop)
        print(f' Face {i + 1}: pitch={np.degrees(pitch):.1f}°, yaw={np.degrees(yaw):.1f}°')

        # Draw both bbox and gaze arrow with angle text
        draw_gaze(image, bbox, pitch, yaw, draw_angles=True)

    os.makedirs(save_dir, exist_ok=True)
    output_path = os.path.join(save_dir, f'{Path(image_path).stem}_gaze.jpg')
    cv2.imwrite(output_path, image)
    print(f'Output saved: {output_path}')


def run_webcam(detector, gaze_estimator):
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print('Cannot open webcam')
        return

    print("Press 'q' to quit")

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame = cv2.flip(frame, 1)
        faces = detector.detect(frame)

        for face in faces:
            bbox = face['bbox']
            x1, y1, x2, y2 = map(int, bbox[:4])
            face_crop = frame[y1:y2, x1:x2]

            if face_crop.size == 0:
                continue

            pitch, yaw = gaze_estimator.estimate(face_crop)
            # Draw both bbox and gaze arrow
            draw_gaze(frame, bbox, pitch, yaw)

        cv2.putText(frame, f'Faces: {len(faces)}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        cv2.imshow('Gaze Estimation', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()


def main():
    parser = argparse.ArgumentParser(description='Run gaze estimation')
    parser.add_argument('--image', type=str, help='Path to input image')
    parser.add_argument('--webcam', action='store_true', help='Use webcam')
    parser.add_argument('--save_dir', type=str, default='outputs')
    args = parser.parse_args()

    if not args.image and not args.webcam:
        parser.error('Either --image or --webcam must be specified')

    detector = RetinaFace()
    gaze_estimator = MobileGaze()

    if args.webcam:
        run_webcam(detector, gaze_estimator)
    else:
        process_image(detector, gaze_estimator, args.image, args.save_dir)


if __name__ == '__main__':
    main()

@@ -13,7 +13,7 @@

__license__ = 'MIT'
__author__ = 'Yakhyokhuja Valikhujaev'
-__version__ = '1.3.2'
+__version__ = '1.4.0'


from uniface.face_utils import compute_similarity, face_alignment

@@ -37,6 +37,7 @@ from .detection import (

    detect_faces,
    list_available_detectors,
)
from .gaze import MobileGaze, create_gaze_estimator
from .landmark import Landmark106, create_landmarker
from .recognition import ArcFace, MobileFace, SphereFace, create_recognizer

@@ -49,6 +50,7 @@ __all__ = [

    'FaceAnalyzer',
    # Factory functions
    'create_detector',
    'create_gaze_estimator',
    'create_landmarker',
    'create_recognizer',
    'detect_faces',

@@ -63,6 +65,8 @@ __all__ = [

    'SphereFace',
    # Landmark models
    'Landmark106',
    # Gaze models
    'MobileGaze',
    # Attribute models
    'AgeGender',
    'Emotion',

@@ -96,6 +96,19 @@ class LandmarkWeights(str, Enum):

    DEFAULT = "2d_106"


class GazeWeights(str, Enum):
    """
    MobileGaze: Real-Time Gaze Estimation models.

    Trained on Gaze360 dataset.
    https://github.com/yakhyo/gaze-estimation
    """

    RESNET18 = "gaze_resnet18"
    RESNET34 = "gaze_resnet34"
    RESNET50 = "gaze_resnet50"
    MOBILENET_V2 = "gaze_mobilenetv2"
    MOBILEONE_S0 = "gaze_mobileone_s0"


MODEL_URLS: Dict[Enum, str] = {
    # RetinaFace
    RetinaFaceWeights.MNET_025: 'https://github.com/yakhyo/uniface/releases/download/weights/retinaface_mv1_0.25.onnx',

@@ -129,6 +142,12 @@ MODEL_URLS: Dict[Enum, str] = {

    AgeGenderWeights.DEFAULT: 'https://github.com/yakhyo/uniface/releases/download/weights/genderage.onnx',
    # Landmarks
    LandmarkWeights.DEFAULT: 'https://github.com/yakhyo/uniface/releases/download/weights/2d106det.onnx',
    # Gaze (MobileGaze)
    GazeWeights.RESNET18: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/resnet18_gaze.onnx',
    GazeWeights.RESNET34: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/resnet34_gaze.onnx',
    GazeWeights.RESNET50: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/resnet50_gaze.onnx',
    GazeWeights.MOBILENET_V2: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/mobilenetv2_gaze.onnx',
    GazeWeights.MOBILEONE_S0: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/mobileone_s0_gaze.onnx',
}

MODEL_SHA256: Dict[Enum, str] = {

@@ -164,6 +183,12 @@ MODEL_SHA256: Dict[Enum, str] = {

    AgeGenderWeights.DEFAULT: '4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb',
    # Landmark
    LandmarkWeights.DEFAULT: 'f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf',
    # MobileGaze (trained on Gaze360)
    GazeWeights.RESNET18: '23d5d7e4f6f40dce8c35274ce9d08b45b9e22cbaaf5af73182f473229d713d31',
    GazeWeights.RESNET34: '4457ee5f7acd1a5ab02da4b61f02fc3a0b17adbf3844dd0ba3cd4288f2b5e1de',
    GazeWeights.RESNET50: 'e1eaf98f5ec7c89c6abe7cfe39f7be83e747163f98d1ff945c0603b3c521be22',
    GazeWeights.MOBILENET_V2: 'fdcdb84e3e6421b5a79e8f95139f249fc258d7f387eed5ddac2b80a9a15ce076',
    GazeWeights.MOBILEONE_S0: 'c0b5a4f4a0ffd24f76ab3c1452354bb2f60110899fd9a88b464c75bafec0fde8',
}

CHUNK_SIZE = 8192

uniface/gaze/__init__.py (58 lines, new file)

@@ -0,0 +1,58 @@

# Copyright 2025 Yakhyokhuja Valikhujaev
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo

from .base import BaseGazeEstimator
from .models import MobileGaze


def create_gaze_estimator(method: str = 'mobilegaze', **kwargs) -> BaseGazeEstimator:
    """
    Factory function to create gaze estimators.

    This function initializes and returns a gaze estimator instance based on the
    specified method. It acts as a high-level interface to the underlying
    model classes.

    Args:
        method (str): The gaze estimation method to use.
            Options: 'mobilegaze' (default).
        **kwargs: Model-specific parameters passed to the estimator's constructor.
            For example, `model_name` can be used to select a specific
            backbone from the `GazeWeights` enum (RESNET18, RESNET34, RESNET50,
            MOBILENET_V2, MOBILEONE_S0).

    Returns:
        BaseGazeEstimator: An initialized gaze estimator instance ready for use.

    Raises:
        ValueError: If the specified `method` is not supported.

    Examples:
        >>> # Create the default MobileGaze estimator (ResNet34 backbone)
        >>> estimator = create_gaze_estimator()

        >>> # Create with MobileNetV2 backbone
        >>> from uniface.constants import GazeWeights
        >>> estimator = create_gaze_estimator(
        ...     'mobilegaze',
        ...     model_name=GazeWeights.MOBILENET_V2
        ... )

        >>> # Use the estimator
        >>> pitch, yaw = estimator.estimate(face_crop)
    """
    method = method.lower()

    if method in ('mobilegaze', 'mobile_gaze', 'gaze'):
        return MobileGaze(**kwargs)
    else:
        available = ['mobilegaze']
        raise ValueError(f"Unsupported gaze estimation method: '{method}'. Available: {available}")


__all__ = [
    'create_gaze_estimator',
    'MobileGaze',
    'BaseGazeEstimator',
]

uniface/gaze/base.py (108 lines, new file)

@@ -0,0 +1,108 @@

# Copyright 2025 Yakhyokhuja Valikhujaev
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo

from abc import ABC, abstractmethod
from typing import Tuple

import numpy as np


class BaseGazeEstimator(ABC):
    """
    Abstract base class for all gaze estimation models.

    This class defines the common interface that all gaze estimators must implement,
    ensuring consistency across different gaze estimation methods. Gaze estimation
    predicts the direction a person is looking based on their face image.

    The gaze direction is represented as pitch and yaw angles in radians:
    - Pitch: Vertical angle (positive = looking up, negative = looking down)
    - Yaw: Horizontal angle (positive = looking right, negative = looking left)
    """

    @abstractmethod
    def _initialize_model(self) -> None:
        """
        Initialize the underlying model for inference.

        This method should handle loading model weights, creating the
        inference session (e.g., ONNX Runtime), and any necessary
        setup procedures to prepare the model for prediction.

        Raises:
            RuntimeError: If the model fails to load or initialize.
        """
        raise NotImplementedError('Subclasses must implement the _initialize_model method.')

    @abstractmethod
    def preprocess(self, face_image: np.ndarray) -> np.ndarray:
        """
        Preprocess the input face image for model inference.

        This method should take a raw face crop and convert it into the format
        expected by the model's inference engine (e.g., normalized tensor).

        Args:
            face_image (np.ndarray): A cropped face image in BGR format with
                shape (H, W, C).

        Returns:
            np.ndarray: The preprocessed image tensor ready for inference,
                typically with shape (1, C, H, W).
        """
        raise NotImplementedError('Subclasses must implement the preprocess method.')

    @abstractmethod
    def postprocess(self, outputs: Tuple[np.ndarray, np.ndarray]) -> Tuple[float, float]:
        """
        Postprocess raw model outputs into gaze angles.

        This method takes the raw output from the model's inference and
        converts it into pitch and yaw angles in radians.

        Args:
            outputs: Raw outputs from the model inference. The format depends
                on the specific model architecture.

        Returns:
            Tuple[float, float]: A tuple of (pitch, yaw) angles in radians.
        """
        raise NotImplementedError('Subclasses must implement the postprocess method.')

    @abstractmethod
    def estimate(self, face_image: np.ndarray) -> Tuple[float, float]:
        """
        Perform end-to-end gaze estimation on a face image.

        This method orchestrates the full pipeline: preprocessing the input,
        running inference, and postprocessing to return the gaze direction.

        Args:
            face_image (np.ndarray): A cropped face image in BGR format.
                The face should be roughly centered and well-framed within
                the image.

        Returns:
            Tuple[float, float]: A tuple of (pitch, yaw) angles in radians:
                - pitch: Vertical gaze angle (positive = up, negative = down)
                - yaw: Horizontal gaze angle (positive = right, negative = left)

        Example:
            >>> estimator = create_gaze_estimator()
            >>> pitch, yaw = estimator.estimate(face_crop)
            >>> print(f"Looking: pitch={np.degrees(pitch):.1f}°, yaw={np.degrees(yaw):.1f}°")
        """
        raise NotImplementedError('Subclasses must implement the estimate method.')

    def __call__(self, face_image: np.ndarray) -> Tuple[float, float]:
        """
        Provides a convenient, callable shortcut for the `estimate` method.

        Args:
            face_image (np.ndarray): A cropped face image in BGR format.

        Returns:
            Tuple[float, float]: A tuple of (pitch, yaw) angles in radians.
        """
        return self.estimate(face_image)

uniface/gaze/models.py (187 lines, new file)

@@ -0,0 +1,187 @@

# Copyright 2025 Yakhyokhuja Valikhujaev
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo

from typing import Tuple

import cv2
import numpy as np

from uniface.constants import GazeWeights
from uniface.log import Logger
from uniface.model_store import verify_model_weights
from uniface.onnx_utils import create_onnx_session

from .base import BaseGazeEstimator

__all__ = ['MobileGaze']


class MobileGaze(BaseGazeEstimator):
    """
    MobileGaze: Real-Time Gaze Estimation with ONNX Runtime.

    MobileGaze is a gaze estimation model that predicts gaze direction from a single
    face image. It supports multiple backbone architectures including ResNet 18/34/50,
    MobileNetV2, and MobileOne S0. The model uses a classification approach with binned
    angles, which are then decoded to continuous pitch and yaw values.

    The model outputs gaze direction as pitch (vertical) and yaw (horizontal) angles
    in radians.

    Reference:
        https://github.com/yakhyo/gaze-estimation

    Args:
        model_name (GazeWeights): The enum specifying the gaze model backbone to load.
            Options: RESNET18, RESNET34, RESNET50, MOBILENET_V2, MOBILEONE_S0.
            Defaults to `GazeWeights.RESNET34`.
        input_size (Tuple[int, int]): The resolution (width, height) for the model's
            input. Defaults to (448, 448).

    Attributes:
        input_size (Tuple[int, int]): Model input dimensions.
        input_mean (list): Per-channel mean values for normalization (ImageNet).
        input_std (list): Per-channel std values for normalization (ImageNet).

    Example:
        >>> from uniface.gaze import MobileGaze
        >>> from uniface import RetinaFace
        >>>
        >>> detector = RetinaFace()
        >>> gaze_estimator = MobileGaze()
        >>>
        >>> # Detect faces and estimate gaze for each
        >>> faces = detector.detect(image)
        >>> for face in faces:
        ...     bbox = face['bbox']
        ...     x1, y1, x2, y2 = map(int, bbox[:4])
        ...     face_crop = image[y1:y2, x1:x2]
        ...     pitch, yaw = gaze_estimator.estimate(face_crop)
        ...     print(f"Gaze: pitch={np.degrees(pitch):.1f}°, yaw={np.degrees(yaw):.1f}°")
    """

    def __init__(
        self,
        model_name: GazeWeights = GazeWeights.RESNET34,
        input_size: Tuple[int, int] = (448, 448),
    ) -> None:
        Logger.info(f'Initializing MobileGaze with model={model_name}, input_size={input_size}')

        self.input_size = input_size
        self.input_mean = [0.485, 0.456, 0.406]
        self.input_std = [0.229, 0.224, 0.225]

        # Model specific parameters for bin-based classification (Gaze360 config)
        self._bins = 90
        self._binwidth = 4
        self._angle_offset = 180
        self._idx_tensor = np.arange(self._bins, dtype=np.float32)

        self.model_path = verify_model_weights(model_name)
        self._initialize_model()

    def _initialize_model(self) -> None:
        """
        Initialize the ONNX model from the stored model path.

        Raises:
            RuntimeError: If the model fails to load or initialize.
        """
        try:
            self.session = create_onnx_session(self.model_path)

            # Get input configuration
            input_cfg = self.session.get_inputs()[0]
            input_shape = input_cfg.shape
            self.input_name = input_cfg.name
            self.input_size = tuple(input_shape[2:4][::-1])  # Update from model

            # Get output configuration
            outputs = self.session.get_outputs()
            self.output_names = [output.name for output in outputs]

            if len(self.output_names) != 2:
                raise ValueError(f'Expected 2 output nodes (pitch, yaw), got {len(self.output_names)}')

            Logger.info(f'MobileGaze initialized with input size {self.input_size}')

        except Exception as e:
            Logger.error(f"Failed to load gaze model from '{self.model_path}'", exc_info=True)
            raise RuntimeError(f'Failed to initialize gaze model: {e}') from e

    def preprocess(self, face_image: np.ndarray) -> np.ndarray:
        """
        Preprocess a face crop for gaze estimation.

        Args:
            face_image (np.ndarray): A cropped face image in BGR format.

        Returns:
            np.ndarray: Preprocessed image tensor with shape (1, 3, H, W).
        """
        # Convert BGR to RGB
        image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)

        # Resize to model input size
        image = cv2.resize(image, self.input_size)

        # Scale to [0, 1] and apply ImageNet mean/std normalization
        image = image.astype(np.float32) / 255.0
        mean = np.array(self.input_mean, dtype=np.float32)
        std = np.array(self.input_std, dtype=np.float32)
        image = (image - mean) / std

        # HWC -> CHW -> NCHW
        image = np.transpose(image, (2, 0, 1))
        image = np.expand_dims(image, axis=0).astype(np.float32)

        return image

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Apply softmax along axis 1."""
        e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
        return e_x / e_x.sum(axis=1, keepdims=True)

    def postprocess(self, outputs: Tuple[np.ndarray, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        """
        Postprocess raw model outputs into gaze angles.

        This method takes the raw output from the model's inference and
        converts it into pitch and yaw angles in radians.

        Args:
            outputs: Raw outputs from the model inference. The format depends
                on the specific model architecture.

        Returns:
            Tuple[np.ndarray, np.ndarray]: A tuple of (pitch, yaw) angles in radians.
        """
        pitch_logits, yaw_logits = outputs

        # Convert logits to probabilities
        pitch_probs = self._softmax(pitch_logits)
        yaw_probs = self._softmax(yaw_logits)

        # Soft-argmax over the bins: expected angle in degrees
        pitch_deg = np.sum(pitch_probs * self._idx_tensor, axis=1) * self._binwidth - self._angle_offset
        yaw_deg = np.sum(yaw_probs * self._idx_tensor, axis=1) * self._binwidth - self._angle_offset

        # Convert degrees to radians
        pitch = np.radians(pitch_deg[0])
        yaw = np.radians(yaw_deg[0])

        return pitch, yaw

    def estimate(self, face_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Perform end-to-end gaze estimation on a face image.

        This method orchestrates the full pipeline: preprocessing the input,
        running inference, and postprocessing to return the gaze direction.
        """
        input_tensor = self.preprocess(face_image)
        outputs = self.session.run(self.output_names, {self.input_name: input_tensor})
        pitch, yaw = self.postprocess((outputs[0], outputs[1]))

        return pitch, yaw

@@ -126,3 +126,97 @@ def draw_fancy_bbox(

    # Bottom-right corner
    cv2.line(image, (x2, y2), (x2, y2 - corner_length), color, thickness)
    cv2.line(image, (x2, y2), (x2 - corner_length, y2), color, thickness)


def draw_gaze(
    image: np.ndarray,
    bbox: np.ndarray,
    pitch: np.ndarray,
    yaw: np.ndarray,
    *,
    draw_bbox: bool = True,
    fancy_bbox: bool = True,
    draw_angles: bool = True,
):
    """
    Draws gaze direction with optional bounding box on an image.

    Args:
        image: Input image to draw on (modified in-place).
        bbox: Face bounding box [x1, y1, x2, y2].
        pitch: Vertical gaze angle in radians.
        yaw: Horizontal gaze angle in radians.
        draw_bbox: Whether to draw the bounding box. Defaults to True.
        fancy_bbox: Use fancy corner-style bbox. Defaults to True.
        draw_angles: Whether to display pitch/yaw values as text. Defaults to True.
    """
    x_min, y_min, x_max, y_max = map(int, bbox[:4])

    # Calculate dynamic line thickness based on image size (same as draw_detections)
    line_thickness = max(round(sum(image.shape[:2]) / 2 * 0.003), 2)

    # Calculate dynamic font scale based on bbox height (same as draw_detections)
    bbox_h = y_max - y_min
    font_scale = max(0.4, min(0.7, bbox_h / 200))
    font_thickness = 2

    # Draw bounding box if requested
    if draw_bbox:
        if fancy_bbox:
            draw_fancy_bbox(image, bbox, color=(0, 255, 0), thickness=line_thickness)
        else:
            cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0, 255, 0), line_thickness)

    # Calculate center of the bounding box
    x_center = (x_min + x_max) // 2
    y_center = (y_min + y_max) // 2

    # Calculate the direction of the gaze
    length = x_max - x_min
    dx = int(-length * np.sin(pitch) * np.cos(yaw))
    dy = int(-length * np.sin(yaw))

    point1 = (x_center, y_center)
    point2 = (x_center + dx, y_center + dy)

    # Calculate dynamic center point radius based on line thickness
    center_radius = max(line_thickness + 1, 4)

    # Draw gaze direction
    cv2.circle(image, (x_center, y_center), radius=center_radius, color=(0, 0, 255), thickness=-1)
    cv2.arrowedLine(
        image,
        point1,
        point2,
        color=(0, 0, 255),
        thickness=line_thickness,
        line_type=cv2.LINE_AA,
        tipLength=0.25,
    )

    # Draw angle values
    if draw_angles:
        text = f'P:{np.degrees(pitch):.0f}deg Y:{np.degrees(yaw):.0f}deg'
        (text_width, text_height), baseline = cv2.getTextSize(
            text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness
        )

        # Draw background rectangle for text
        cv2.rectangle(
            image,
            (x_min, y_min - text_height - baseline - 10),
            (x_min + text_width + 10, y_min),
            (0, 0, 255),
            -1,
        )

        # Draw text
        cv2.putText(
            image,
            text,
            (x_min + 5, y_min - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            (255, 255, 255),
            font_thickness,
        )