mirror of
https://github.com/yakhyo/uniface.git
synced 2025-12-30 00:52:25 +00:00
feat: Add 2D Gaze estimation models (#34)
* feat: Add Gaze Estimation, update docs and Add example notebook, inference code * docs: Update README.md
This commit is contained in:
committed by
GitHub
parent
da8a5cf35b
commit
4d1921e531
@@ -58,3 +58,6 @@ Open an issue or start a discussion on GitHub.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
42
MODELS.md
42
MODELS.md
@@ -291,6 +291,47 @@ emotion, confidence = predictor.predict(image, landmarks)
|
||||
|
||||
---
|
||||
|
||||
## Gaze Estimation Models
|
||||
|
||||
### MobileGaze Family
|
||||
|
||||
Real-time gaze direction prediction models trained on Gaze360 dataset. Returns pitch (vertical) and yaw (horizontal) angles in radians.
|
||||
|
||||
| Model Name | Params | Size | MAE* | Use Case |
|
||||
| -------------- | ------ | ------- | ----- | ----------------------------- |
|
||||
| `RESNET18` | 11.7M | 43 MB | 12.84 | Balanced accuracy/speed |
|
||||
| `RESNET34` ⭐ | 24.8M | 81.6 MB | 11.33 | **Recommended default** |
|
||||
| `RESNET50` | 25.6M | 91.3 MB | 11.34 | High accuracy |
|
||||
| `MOBILENET_V2` | 3.5M | 9.59 MB | 13.07 | Mobile/Edge devices |
|
||||
| `MOBILEONE_S0` | 2.1M | 4.8 MB | 12.58 | Lightweight/Real-time |
|
||||
|
||||
*MAE (Mean Absolute Error) in degrees on Gaze360 test set - lower is better
|
||||
|
||||
**Dataset**: Trained on Gaze360 (indoor/outdoor scenes with diverse head poses)
|
||||
**Training**: 200 epochs with classification-based approach (binned angles)
|
||||
|
||||
#### Usage
|
||||
|
||||
```python
|
||||
from uniface import MobileGaze
|
||||
from uniface.constants import GazeWeights
|
||||
import numpy as np
|
||||
|
||||
# Default (recommended)
|
||||
gaze_estimator = MobileGaze() # Uses RESNET34
|
||||
|
||||
# Lightweight model
|
||||
gaze_estimator = MobileGaze(model_name=GazeWeights.MOBILEONE_S0)
|
||||
|
||||
# Estimate gaze from face crop
|
||||
pitch, yaw = gaze_estimator.estimate(face_crop)
|
||||
print(f"Pitch: {np.degrees(pitch):.1f}°, Yaw: {np.degrees(yaw):.1f}°")
|
||||
```
|
||||
|
||||
**Note**: Requires face crop as input. Use face detection first to obtain bounding boxes.
|
||||
|
||||
---
|
||||
|
||||
## Model Updates
|
||||
|
||||
Models are automatically downloaded and cached on first use. Cache location: `~/.uniface/models/`
|
||||
@@ -330,6 +371,7 @@ python scripts/download_model.py --model MNET_V2
|
||||
- **YOLOv5-Face Original**: [deepcam-cn/yolov5-face](https://github.com/deepcam-cn/yolov5-face) - Original PyTorch implementation
|
||||
- **YOLOv5-Face ONNX**: [yakhyo/yolov5-face-onnx-inference](https://github.com/yakhyo/yolov5-face-onnx-inference) - ONNX inference implementation
|
||||
- **Face Recognition Training**: [yakhyo/face-recognition](https://github.com/yakhyo/face-recognition) - ArcFace, MobileFace, SphereFace training code
|
||||
- **Gaze Estimation Training**: [yakhyo/gaze-estimation](https://github.com/yakhyo/gaze-estimation) - MobileGaze training code and pretrained weights
|
||||
- **InsightFace**: [deepinsight/insightface](https://github.com/deepinsight/insightface) - Model architectures and pretrained weights
|
||||
|
||||
### Papers
|
||||
|
||||
@@ -242,7 +242,50 @@ if faces:
|
||||
|
||||
---
|
||||
|
||||
## 7. Batch Processing (3 minutes)
|
||||
## 7. Gaze Estimation (2 minutes)
|
||||
|
||||
Estimate where a person is looking:
|
||||
|
||||
```python
|
||||
import cv2
|
||||
import numpy as np
|
||||
from uniface import RetinaFace, MobileGaze
|
||||
from uniface.visualization import draw_gaze
|
||||
|
||||
# Initialize models
|
||||
detector = RetinaFace()
|
||||
gaze_estimator = MobileGaze()
|
||||
|
||||
# Load image
|
||||
image = cv2.imread("photo.jpg")
|
||||
faces = detector.detect(image)
|
||||
|
||||
# Estimate gaze for each face
|
||||
for i, face in enumerate(faces):
|
||||
bbox = face['bbox']
|
||||
x1, y1, x2, y2 = map(int, bbox[:4])
|
||||
face_crop = image[y1:y2, x1:x2]
|
||||
|
||||
if face_crop.size > 0:
|
||||
pitch, yaw = gaze_estimator.estimate(face_crop)
|
||||
print(f"Face {i+1}: pitch={np.degrees(pitch):.1f}°, yaw={np.degrees(yaw):.1f}°")
|
||||
|
||||
# Draw gaze direction
|
||||
draw_gaze(image, bbox, pitch, yaw)
|
||||
|
||||
cv2.imwrite("gaze_output.jpg", image)
|
||||
```
|
||||
|
||||
**Output:**
|
||||
|
||||
```
|
||||
Face 1: pitch=5.2°, yaw=-12.3°
|
||||
Face 2: pitch=-8.1°, yaw=15.7°
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. Batch Processing (3 minutes)
|
||||
|
||||
Process multiple images:
|
||||
|
||||
@@ -275,7 +318,7 @@ print("Done!")
|
||||
|
||||
---
|
||||
|
||||
## 8. Model Selection
|
||||
## 9. Model Selection
|
||||
|
||||
Choose the right model for your use case:
|
||||
|
||||
@@ -326,6 +369,22 @@ recognizer = MobileFace(model_name=MobileFaceWeights.MNET_V2) # Fast, small siz
|
||||
recognizer = SphereFace(model_name=SphereFaceWeights.SPHERE20) # Alternative method
|
||||
```
|
||||
|
||||
### Gaze Estimation Models
|
||||
|
||||
```python
|
||||
from uniface import MobileGaze
|
||||
from uniface.constants import GazeWeights
|
||||
|
||||
# Default (recommended)
|
||||
gaze_estimator = MobileGaze() # Uses RESNET34
|
||||
|
||||
# Lightweight (mobile/edge devices)
|
||||
gaze_estimator = MobileGaze(model_name=GazeWeights.MOBILEONE_S0)
|
||||
|
||||
# High accuracy
|
||||
gaze_estimator = MobileGaze(model_name=GazeWeights.RESNET50)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Common Issues
|
||||
@@ -400,4 +459,5 @@ Explore interactive examples for common tasks:
|
||||
- **RetinaFace Training**: [yakhyo/retinaface-pytorch](https://github.com/yakhyo/retinaface-pytorch)
|
||||
- **YOLOv5-Face ONNX**: [yakhyo/yolov5-face-onnx-inference](https://github.com/yakhyo/yolov5-face-onnx-inference)
|
||||
- **Face Recognition Training**: [yakhyo/face-recognition](https://github.com/yakhyo/face-recognition)
|
||||
- **Gaze Estimation Training**: [yakhyo/gaze-estimation](https://github.com/yakhyo/gaze-estimation)
|
||||
- **InsightFace**: [deepinsight/insightface](https://github.com/deepinsight/insightface)
|
||||
|
||||
33
README.md
33
README.md
@@ -20,6 +20,7 @@
|
||||
- **High-Speed Face Detection**: ONNX-optimized RetinaFace, SCRFD, and YOLOv5-Face models
|
||||
- **Facial Landmark Detection**: Accurate 106-point landmark localization
|
||||
- **Face Recognition**: ArcFace, MobileFace, and SphereFace embeddings
|
||||
- **Gaze Estimation**: Real-time gaze direction prediction with MobileGaze
|
||||
- **Attribute Analysis**: Age, gender, and emotion detection
|
||||
- **Face Alignment**: Precise alignment for downstream tasks
|
||||
- **Hardware Acceleration**: ARM64 optimizations (Apple Silicon), CUDA (NVIDIA), CPU fallback
|
||||
@@ -152,6 +153,29 @@ gender_str = 'Female' if gender == 0 else 'Male'
|
||||
print(f"{gender_str}, {age} years old")
|
||||
```
|
||||
|
||||
### Gaze Estimation
|
||||
|
||||
```python
|
||||
from uniface import RetinaFace, MobileGaze
|
||||
from uniface.visualization import draw_gaze
|
||||
import numpy as np
|
||||
|
||||
detector = RetinaFace()
|
||||
gaze_estimator = MobileGaze()
|
||||
|
||||
faces = detector.detect(image)
|
||||
for face in faces:
|
||||
bbox = face['bbox']
|
||||
x1, y1, x2, y2 = map(int, bbox[:4])
|
||||
face_crop = image[y1:y2, x1:x2]
|
||||
|
||||
pitch, yaw = gaze_estimator.estimate(face_crop)
|
||||
print(f"Gaze: pitch={np.degrees(pitch):.1f}°, yaw={np.degrees(yaw):.1f}°")
|
||||
|
||||
# Visualize
|
||||
draw_gaze(image, bbox, pitch, yaw)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Documentation
|
||||
@@ -252,6 +276,12 @@ faces = detect_faces(image, method='retinaface', conf_thresh=0.8) # methods: re
|
||||
| `AgeGender` | `model_name=AgeGenderWeights.DEFAULT`; `input_size` auto-detected | Requires bbox; ONNXRuntime |
|
||||
| `Emotion` | `model_weights=DDAMFNWeights.AFFECNET7`, `input_size=(112, 112)` | Requires 5-point landmarks; TorchScript |
|
||||
|
||||
**Gaze Estimation**
|
||||
|
||||
| Class | Key params (defaults) | Notes |
|
||||
| ------------- | ------------------------------------------ | ------------------------------------ |
|
||||
| `MobileGaze` | `model_name=GazeWeights.RESNET34` | Returns (pitch, yaw) angles in radians; trained on Gaze360 |
|
||||
|
||||
---
|
||||
|
||||
## Model Performance
|
||||
@@ -298,6 +328,7 @@ Interactive examples covering common face analysis tasks:
|
||||
| **Face Recognition** | Extract face embeddings and compare faces | [face_analyzer.ipynb](examples/face_analyzer.ipynb) |
|
||||
| **Face Verification** | Compare two faces to verify identity | [face_verification.ipynb](examples/face_verification.ipynb) |
|
||||
| **Face Search** | Find a person in a group photo | [face_search.ipynb](examples/face_search.ipynb) |
|
||||
| **Gaze Estimation** | Estimate gaze direction from face images | [gaze_estimation.ipynb](examples/gaze_estimation.ipynb) |
|
||||
|
||||
### Webcam Face Detection
|
||||
|
||||
@@ -488,6 +519,7 @@ uniface/
|
||||
│ ├── detection/ # Face detection models
|
||||
│ ├── recognition/ # Face recognition models
|
||||
│ ├── landmark/ # Landmark detection
|
||||
│ ├── gaze/ # Gaze estimation
|
||||
│ ├── attribute/ # Age, gender, emotion
|
||||
│ ├── onnx_utils.py # ONNX Runtime utilities
|
||||
│ ├── model_store.py # Model download & caching
|
||||
@@ -504,6 +536,7 @@ uniface/
|
||||
- **RetinaFace Training**: [yakhyo/retinaface-pytorch](https://github.com/yakhyo/retinaface-pytorch) - PyTorch implementation and training code
|
||||
- **YOLOv5-Face ONNX**: [yakhyo/yolov5-face-onnx-inference](https://github.com/yakhyo/yolov5-face-onnx-inference) - ONNX inference implementation
|
||||
- **Face Recognition Training**: [yakhyo/face-recognition](https://github.com/yakhyo/face-recognition) - ArcFace, MobileFace, SphereFace training code
|
||||
- **Gaze Estimation Training**: [yakhyo/gaze-estimation](https://github.com/yakhyo/gaze-estimation) - MobileGaze training code and pretrained weights
|
||||
- **InsightFace**: [deepinsight/insightface](https://github.com/deepinsight/insightface) - Model architectures and pretrained weights
|
||||
|
||||
## Contributing
|
||||
|
||||
271
examples/gaze_estimation.ipynb
Normal file
271
examples/gaze_estimation.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "uniface"
|
||||
version = "1.3.2"
|
||||
description = "UniFace: A Comprehensive Library for Face Detection, Recognition, Landmark Analysis, Age, and Gender Detection"
|
||||
version = "1.4.0"
|
||||
description = "UniFace: A Comprehensive Library for Face Detection, Recognition, Landmark Analysis, Gaze Estimation, Age, and Gender Detection"
|
||||
readme = "README.md"
|
||||
license = { text = "MIT" }
|
||||
authors = [{ name = "Yakhyokhuja Valikhujaev", email = "yakhyo9696@gmail.com" }]
|
||||
@@ -14,6 +14,7 @@ keywords = [
|
||||
"face-detection",
|
||||
"face-recognition",
|
||||
"facial-landmarks",
|
||||
"gaze-estimation",
|
||||
"age-detection",
|
||||
"gender-detection",
|
||||
"computer-vision",
|
||||
|
||||
@@ -9,6 +9,7 @@ Scripts for testing UniFace features.
|
||||
| `run_detection.py` | Face detection on image or webcam |
|
||||
| `run_age_gender.py` | Age and gender prediction |
|
||||
| `run_emotion.py` | Emotion detection (7 or 8 emotions) |
|
||||
| `run_gaze_estimation.py` | Gaze direction estimation |
|
||||
| `run_landmarks.py` | 106-point facial landmark detection |
|
||||
| `run_recognition.py` | Face embedding extraction and comparison |
|
||||
| `run_face_analyzer.py` | Complete face analysis (detection + recognition + attributes) |
|
||||
@@ -33,6 +34,10 @@ python scripts/run_age_gender.py --webcam
|
||||
python scripts/run_emotion.py --image assets/test.jpg
|
||||
python scripts/run_emotion.py --webcam
|
||||
|
||||
# Gaze estimation
|
||||
python scripts/run_gaze_estimation.py --image assets/test.jpg
|
||||
python scripts/run_gaze_estimation.py --webcam
|
||||
|
||||
# Landmarks
|
||||
python scripts/run_landmarks.py --image assets/test.jpg
|
||||
python scripts/run_landmarks.py --webcam
|
||||
|
||||
@@ -79,7 +79,9 @@ def run_webcam(detector, age_gender, threshold: float = 0.6):
|
||||
bboxes = [f['bbox'] for f in faces]
|
||||
scores = [f['confidence'] for f in faces]
|
||||
landmarks = [f['landmarks'] for f in faces]
|
||||
draw_detections(frame, bboxes, scores, landmarks, vis_threshold=threshold)
|
||||
draw_detections(
|
||||
image=frame, bboxes=bboxes, scores=scores, landmarks=landmarks, vis_threshold=threshold, fancy_bbox=True
|
||||
)
|
||||
|
||||
for face in faces:
|
||||
gender_id, age = age_gender.predict(frame, face['bbox']) # predict per face
|
||||
|
||||
@@ -98,7 +98,7 @@ def main():
|
||||
else:
|
||||
from uniface.constants import YOLOv5FaceWeights
|
||||
|
||||
detector = YOLOv5Face(model_name=YOLOv5FaceWeights.YOLOV5N)
|
||||
detector = YOLOv5Face(model_name=YOLOv5FaceWeights.YOLOV5M)
|
||||
|
||||
if args.webcam:
|
||||
run_webcam(detector, args.threshold)
|
||||
|
||||
104
scripts/run_gaze_estimation.py
Normal file
104
scripts/run_gaze_estimation.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# Gaze estimation on detected faces
|
||||
# Usage: python run_gaze_estimation.py --image path/to/image.jpg
|
||||
# python run_gaze_estimation.py --webcam
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from uniface import RetinaFace
|
||||
from uniface.gaze import MobileGaze
|
||||
from uniface.visualization import draw_gaze
|
||||
|
||||
|
||||
def process_image(detector, gaze_estimator, image_path: str, save_dir: str = 'outputs'):
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
print(f"Error: Failed to load image from '{image_path}'")
|
||||
return
|
||||
|
||||
faces = detector.detect(image)
|
||||
print(f'Detected {len(faces)} face(s)')
|
||||
|
||||
for i, face in enumerate(faces):
|
||||
bbox = face['bbox']
|
||||
x1, y1, x2, y2 = map(int, bbox[:4])
|
||||
face_crop = image[y1:y2, x1:x2]
|
||||
|
||||
if face_crop.size == 0:
|
||||
continue
|
||||
|
||||
pitch, yaw = gaze_estimator.estimate(face_crop)
|
||||
print(f' Face {i + 1}: pitch={np.degrees(pitch):.1f}°, yaw={np.degrees(yaw):.1f}°')
|
||||
|
||||
# Draw both bbox and gaze arrow with angle text
|
||||
draw_gaze(image, bbox, pitch, yaw, draw_angles=True)
|
||||
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
output_path = os.path.join(save_dir, f'{Path(image_path).stem}_gaze.jpg')
|
||||
cv2.imwrite(output_path, image)
|
||||
print(f'Output saved: {output_path}')
|
||||
|
||||
|
||||
def run_webcam(detector, gaze_estimator):
|
||||
cap = cv2.VideoCapture(0)
|
||||
if not cap.isOpened():
|
||||
print('Cannot open webcam')
|
||||
return
|
||||
|
||||
print("Press 'q' to quit")
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
frame = cv2.flip(frame, 1)
|
||||
faces = detector.detect(frame)
|
||||
|
||||
for face in faces:
|
||||
bbox = face['bbox']
|
||||
x1, y1, x2, y2 = map(int, bbox[:4])
|
||||
face_crop = frame[y1:y2, x1:x2]
|
||||
|
||||
if face_crop.size == 0:
|
||||
continue
|
||||
|
||||
pitch, yaw = gaze_estimator.estimate(face_crop)
|
||||
# Draw both bbox and gaze arrow
|
||||
draw_gaze(frame, bbox, pitch, yaw)
|
||||
|
||||
cv2.putText(frame, f'Faces: {len(faces)}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
|
||||
cv2.imshow('Gaze Estimation', frame)
|
||||
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Run gaze estimation')
|
||||
parser.add_argument('--image', type=str, help='Path to input image')
|
||||
parser.add_argument('--webcam', action='store_true', help='Use webcam')
|
||||
parser.add_argument('--save_dir', type=str, default='outputs')
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.image and not args.webcam:
|
||||
parser.error('Either --image or --webcam must be specified')
|
||||
|
||||
detector = RetinaFace()
|
||||
gaze_estimator = MobileGaze()
|
||||
|
||||
if args.webcam:
|
||||
run_webcam(detector, gaze_estimator)
|
||||
else:
|
||||
process_image(detector, gaze_estimator, args.image, args.save_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
__license__ = 'MIT'
|
||||
__author__ = 'Yakhyokhuja Valikhujaev'
|
||||
__version__ = '1.3.2'
|
||||
__version__ = '1.4.0'
|
||||
|
||||
|
||||
from uniface.face_utils import compute_similarity, face_alignment
|
||||
@@ -37,6 +37,7 @@ from .detection import (
|
||||
detect_faces,
|
||||
list_available_detectors,
|
||||
)
|
||||
from .gaze import MobileGaze, create_gaze_estimator
|
||||
from .landmark import Landmark106, create_landmarker
|
||||
from .recognition import ArcFace, MobileFace, SphereFace, create_recognizer
|
||||
|
||||
@@ -49,6 +50,7 @@ __all__ = [
|
||||
'FaceAnalyzer',
|
||||
# Factory functions
|
||||
'create_detector',
|
||||
'create_gaze_estimator',
|
||||
'create_landmarker',
|
||||
'create_recognizer',
|
||||
'detect_faces',
|
||||
@@ -63,6 +65,8 @@ __all__ = [
|
||||
'SphereFace',
|
||||
# Landmark models
|
||||
'Landmark106',
|
||||
# Gaze models
|
||||
'MobileGaze',
|
||||
# Attribute models
|
||||
'AgeGender',
|
||||
'Emotion',
|
||||
|
||||
@@ -96,6 +96,19 @@ class LandmarkWeights(str, Enum):
|
||||
DEFAULT = "2d_106"
|
||||
|
||||
|
||||
class GazeWeights(str, Enum):
|
||||
"""
|
||||
MobileGaze: Real-Time Gaze Estimation models.
|
||||
Trained on Gaze360 dataset.
|
||||
https://github.com/yakhyo/gaze-estimation
|
||||
"""
|
||||
RESNET18 = "gaze_resnet18"
|
||||
RESNET34 = "gaze_resnet34"
|
||||
RESNET50 = "gaze_resnet50"
|
||||
MOBILENET_V2 = "gaze_mobilenetv2"
|
||||
MOBILEONE_S0 = "gaze_mobileone_s0"
|
||||
|
||||
|
||||
MODEL_URLS: Dict[Enum, str] = {
|
||||
# RetinaFace
|
||||
RetinaFaceWeights.MNET_025: 'https://github.com/yakhyo/uniface/releases/download/weights/retinaface_mv1_0.25.onnx',
|
||||
@@ -129,6 +142,12 @@ MODEL_URLS: Dict[Enum, str] = {
|
||||
AgeGenderWeights.DEFAULT: 'https://github.com/yakhyo/uniface/releases/download/weights/genderage.onnx',
|
||||
# Landmarks
|
||||
LandmarkWeights.DEFAULT: 'https://github.com/yakhyo/uniface/releases/download/weights/2d106det.onnx',
|
||||
# Gaze (MobileGaze)
|
||||
GazeWeights.RESNET18: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/resnet18_gaze.onnx',
|
||||
GazeWeights.RESNET34: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/resnet34_gaze.onnx',
|
||||
GazeWeights.RESNET50: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/resnet50_gaze.onnx',
|
||||
GazeWeights.MOBILENET_V2: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/mobilenetv2_gaze.onnx',
|
||||
GazeWeights.MOBILEONE_S0: 'https://github.com/yakhyo/gaze-estimation/releases/download/weights/mobileone_s0_gaze.onnx',
|
||||
}
|
||||
|
||||
MODEL_SHA256: Dict[Enum, str] = {
|
||||
@@ -164,6 +183,12 @@ MODEL_SHA256: Dict[Enum, str] = {
|
||||
AgeGenderWeights.DEFAULT: '4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb',
|
||||
# Landmark
|
||||
LandmarkWeights.DEFAULT: 'f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf',
|
||||
# MobileGaze (trained on Gaze360)
|
||||
GazeWeights.RESNET18: '23d5d7e4f6f40dce8c35274ce9d08b45b9e22cbaaf5af73182f473229d713d31',
|
||||
GazeWeights.RESNET34: '4457ee5f7acd1a5ab02da4b61f02fc3a0b17adbf3844dd0ba3cd4288f2b5e1de',
|
||||
GazeWeights.RESNET50: 'e1eaf98f5ec7c89c6abe7cfe39f7be83e747163f98d1ff945c0603b3c521be22',
|
||||
GazeWeights.MOBILENET_V2: 'fdcdb84e3e6421b5a79e8f95139f249fc258d7f387eed5ddac2b80a9a15ce076',
|
||||
GazeWeights.MOBILEONE_S0: 'c0b5a4f4a0ffd24f76ab3c1452354bb2f60110899fd9a88b464c75bafec0fde8',
|
||||
}
|
||||
|
||||
CHUNK_SIZE = 8192
|
||||
|
||||
58
uniface/gaze/__init__.py
Normal file
58
uniface/gaze/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# Copyright 2025 Yakhyokhuja Valikhujaev
|
||||
# Author: Yakhyokhuja Valikhujaev
|
||||
# GitHub: https://github.com/yakhyo
|
||||
|
||||
from .base import BaseGazeEstimator
|
||||
from .models import MobileGaze
|
||||
|
||||
|
||||
def create_gaze_estimator(method: str = 'mobilegaze', **kwargs) -> BaseGazeEstimator:
|
||||
"""
|
||||
Factory function to create gaze estimators.
|
||||
|
||||
This function initializes and returns a gaze estimator instance based on the
|
||||
specified method. It acts as a high-level interface to the underlying
|
||||
model classes.
|
||||
|
||||
Args:
|
||||
method (str): The gaze estimation method to use.
|
||||
Options: 'mobilegaze' (default).
|
||||
**kwargs: Model-specific parameters passed to the estimator's constructor.
|
||||
For example, `model_name` can be used to select a specific
|
||||
backbone from `GazeWeights` enum (RESNET18, RESNET34, RESNET50,
|
||||
MOBILENET_V2, MOBILEONE_S0).
|
||||
|
||||
Returns:
|
||||
BaseGazeEstimator: An initialized gaze estimator instance ready for use.
|
||||
|
||||
Raises:
|
||||
ValueError: If the specified `method` is not supported.
|
||||
|
||||
Examples:
|
||||
>>> # Create the default MobileGaze estimator (ResNet18 backbone)
|
||||
>>> estimator = create_gaze_estimator()
|
||||
|
||||
>>> # Create with MobileNetV2 backbone
|
||||
>>> from uniface.constants import GazeWeights
|
||||
>>> estimator = create_gaze_estimator(
|
||||
... 'mobilegaze',
|
||||
... model_name=GazeWeights.MOBILENET_V2
|
||||
... )
|
||||
|
||||
>>> # Use the estimator
|
||||
>>> pitch, yaw = estimator.estimate(face_crop)
|
||||
"""
|
||||
method = method.lower()
|
||||
|
||||
if method in ('mobilegaze', 'mobile_gaze', 'gaze'):
|
||||
return MobileGaze(**kwargs)
|
||||
else:
|
||||
available = ['mobilegaze']
|
||||
raise ValueError(f"Unsupported gaze estimation method: '{method}'. Available: {available}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
'create_gaze_estimator',
|
||||
'MobileGaze',
|
||||
'BaseGazeEstimator',
|
||||
]
|
||||
108
uniface/gaze/base.py
Normal file
108
uniface/gaze/base.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Copyright 2025 Yakhyokhuja Valikhujaev
|
||||
# Author: Yakhyokhuja Valikhujaev
|
||||
# GitHub: https://github.com/yakhyo
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BaseGazeEstimator(ABC):
|
||||
"""
|
||||
Abstract base class for all gaze estimation models.
|
||||
|
||||
This class defines the common interface that all gaze estimators must implement,
|
||||
ensuring consistency across different gaze estimation methods. Gaze estimation
|
||||
predicts the direction a person is looking based on their face image.
|
||||
|
||||
The gaze direction is represented as pitch and yaw angles in radians:
|
||||
- Pitch: Vertical angle (positive = looking up, negative = looking down)
|
||||
- Yaw: Horizontal angle (positive = looking right, negative = looking left)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _initialize_model(self) -> None:
|
||||
"""
|
||||
Initialize the underlying model for inference.
|
||||
|
||||
This method should handle loading model weights, creating the
|
||||
inference session (e.g., ONNX Runtime), and any necessary
|
||||
setup procedures to prepare the model for prediction.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the model fails to load or initialize.
|
||||
"""
|
||||
raise NotImplementedError('Subclasses must implement the _initialize_model method.')
|
||||
|
||||
@abstractmethod
|
||||
def preprocess(self, face_image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Preprocess the input face image for model inference.
|
||||
|
||||
This method should take a raw face crop and convert it into the format
|
||||
expected by the model's inference engine (e.g., normalized tensor).
|
||||
|
||||
Args:
|
||||
face_image (np.ndarray): A cropped face image in BGR format with
|
||||
shape (H, W, C).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The preprocessed image tensor ready for inference,
|
||||
typically with shape (1, C, H, W).
|
||||
"""
|
||||
raise NotImplementedError('Subclasses must implement the preprocess method.')
|
||||
|
||||
@abstractmethod
|
||||
def postprocess(self, outputs: Tuple[np.ndarray, np.ndarray]) -> Tuple[float, float]:
|
||||
"""
|
||||
Postprocess raw model outputs into gaze angles.
|
||||
|
||||
This method takes the raw output from the model's inference and
|
||||
converts it into pitch and yaw angles in radians.
|
||||
|
||||
Args:
|
||||
outputs: Raw outputs from the model inference. The format depends
|
||||
on the specific model architecture.
|
||||
|
||||
Returns:
|
||||
Tuple[float, float]: A tuple of (pitch, yaw) angles in radians.
|
||||
"""
|
||||
raise NotImplementedError('Subclasses must implement the postprocess method.')
|
||||
|
||||
@abstractmethod
|
||||
def estimate(self, face_image: np.ndarray) -> Tuple[float, float]:
|
||||
"""
|
||||
Perform end-to-end gaze estimation on a face image.
|
||||
|
||||
This method orchestrates the full pipeline: preprocessing the input,
|
||||
running inference, and postprocessing to return the gaze direction.
|
||||
|
||||
Args:
|
||||
face_image (np.ndarray): A cropped face image in BGR format.
|
||||
The face should be roughly centered and
|
||||
well-framed within the image.
|
||||
|
||||
Returns:
|
||||
Tuple[float, float]: A tuple of (pitch, yaw) angles in radians:
|
||||
- pitch: Vertical gaze angle (positive = up, negative = down)
|
||||
- yaw: Horizontal gaze angle (positive = right, negative = left)
|
||||
|
||||
Example:
|
||||
>>> estimator = create_gaze_estimator()
|
||||
>>> pitch, yaw = estimator.estimate(face_crop)
|
||||
>>> print(f"Looking: pitch={np.degrees(pitch):.1f}°, yaw={np.degrees(yaw):.1f}°")
|
||||
"""
|
||||
raise NotImplementedError('Subclasses must implement the estimate method.')
|
||||
|
||||
def __call__(self, face_image: np.ndarray) -> Tuple[float, float]:
|
||||
"""
|
||||
Provides a convenient, callable shortcut for the `estimate` method.
|
||||
|
||||
Args:
|
||||
face_image (np.ndarray): A cropped face image in BGR format.
|
||||
|
||||
Returns:
|
||||
Tuple[float, float]: A tuple of (pitch, yaw) angles in radians.
|
||||
"""
|
||||
return self.estimate(face_image)
|
||||
187
uniface/gaze/models.py
Normal file
187
uniface/gaze/models.py
Normal file
@@ -0,0 +1,187 @@
|
||||
# Copyright 2025 Yakhyokhuja Valikhujaev
|
||||
# Author: Yakhyokhuja Valikhujaev
|
||||
# GitHub: https://github.com/yakhyo
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from uniface.constants import GazeWeights
|
||||
from uniface.log import Logger
|
||||
from uniface.model_store import verify_model_weights
|
||||
from uniface.onnx_utils import create_onnx_session
|
||||
|
||||
from .base import BaseGazeEstimator
|
||||
|
||||
__all__ = ['MobileGaze']
|
||||
|
||||
|
||||
class MobileGaze(BaseGazeEstimator):
|
||||
"""
|
||||
MobileGaze: Real-Time Gaze Estimation with ONNX Runtime.
|
||||
|
||||
MobileGaze is a gaze estimation model that predicts gaze direction from a single
|
||||
face image. It supports multiple backbone architectures including ResNet 18/34/50,
|
||||
MobileNetV2, and MobileOne S0. The model uses a classification approach with binned
|
||||
angles, which are then decoded to continuous pitch and yaw values.
|
||||
|
||||
The model outputs gaze direction as pitch (vertical) and yaw (horizontal) angles
|
||||
in radians.
|
||||
|
||||
Reference:
|
||||
https://github.com/yakhyo/gaze-estimation
|
||||
|
||||
Args:
|
||||
model_name (GazeWeights): The enum specifying the gaze model backbone to load.
|
||||
Options: RESNET18, RESNET34, RESNET50, MOBILENET_V2, MOBILEONE_S0.
|
||||
Defaults to `GazeWeights.RESNET18`.
|
||||
input_size (Tuple[int, int]): The resolution (width, height) for the model's
|
||||
input. Defaults to (448, 448).
|
||||
|
||||
Attributes:
|
||||
input_size (Tuple[int, int]): Model input dimensions.
|
||||
input_mean (list): Per-channel mean values for normalization (ImageNet).
|
||||
input_std (list): Per-channel std values for normalization (ImageNet).
|
||||
|
||||
Example:
|
||||
>>> from uniface.gaze import MobileGaze
|
||||
>>> from uniface import RetinaFace
|
||||
>>>
|
||||
>>> detector = RetinaFace()
|
||||
>>> gaze_estimator = MobileGaze()
|
||||
>>>
|
||||
>>> # Detect faces and estimate gaze for each
|
||||
>>> faces = detector.detect(image)
|
||||
>>> for face in faces:
|
||||
... bbox = face['bbox']
|
||||
... x1, y1, x2, y2 = map(int, bbox[:4])
|
||||
... face_crop = image[y1:y2, x1:x2]
|
||||
... pitch, yaw = gaze_estimator.estimate(face_crop)
|
||||
... print(f"Gaze: pitch={np.degrees(pitch):.1f}°, yaw={np.degrees(yaw):.1f}°")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: GazeWeights = GazeWeights.RESNET34,
|
||||
input_size: Tuple[int, int] = (448, 448),
|
||||
) -> None:
|
||||
Logger.info(f'Initializing MobileGaze with model={model_name}, input_size={input_size}')
|
||||
|
||||
self.input_size = input_size
|
||||
self.input_mean = [0.485, 0.456, 0.406]
|
||||
self.input_std = [0.229, 0.224, 0.225]
|
||||
|
||||
# Model specific parameters for bin-based classification (Gaze360 config)
|
||||
self._bins = 90
|
||||
self._binwidth = 4
|
||||
self._angle_offset = 180
|
||||
self._idx_tensor = np.arange(self._bins, dtype=np.float32)
|
||||
|
||||
self.model_path = verify_model_weights(model_name)
|
||||
self._initialize_model()
|
||||
|
||||
def _initialize_model(self) -> None:
|
||||
"""
|
||||
Initialize the ONNX model from the stored model path.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the model fails to load or initialize.
|
||||
"""
|
||||
try:
|
||||
self.session = create_onnx_session(self.model_path)
|
||||
|
||||
# Get input configuration
|
||||
input_cfg = self.session.get_inputs()[0]
|
||||
input_shape = input_cfg.shape
|
||||
self.input_name = input_cfg.name
|
||||
self.input_size = tuple(input_shape[2:4][::-1]) # Update from model
|
||||
|
||||
# Get output configuration
|
||||
outputs = self.session.get_outputs()
|
||||
self.output_names = [output.name for output in outputs]
|
||||
|
||||
if len(self.output_names) != 2:
|
||||
raise ValueError(f'Expected 2 output nodes (pitch, yaw), got {len(self.output_names)}')
|
||||
|
||||
Logger.info(f'MobileGaze initialized with input size {self.input_size}')
|
||||
|
||||
except Exception as e:
|
||||
Logger.error(f"Failed to load gaze model from '{self.model_path}'", exc_info=True)
|
||||
raise RuntimeError(f'Failed to initialize gaze model: {e}') from e
|
||||
|
||||
def preprocess(self, face_image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Preprocess a face crop for gaze estimation.
|
||||
|
||||
Args:
|
||||
face_image (np.ndarray): A cropped face image in BGR format.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Preprocessed image tensor with shape (1, 3, H, W).
|
||||
"""
|
||||
# Convert BGR to RGB
|
||||
image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize to model input size
|
||||
image = cv2.resize(image, self.input_size)
|
||||
|
||||
# Normalize to [0, 1] and apply normalization
|
||||
image = image.astype(np.float32) / 255.0
|
||||
mean = np.array(self.input_mean, dtype=np.float32)
|
||||
std = np.array(self.input_std, dtype=np.float32)
|
||||
image = (image - mean) / std
|
||||
|
||||
# HWC -> CHW -> NCHW
|
||||
image = np.transpose(image, (2, 0, 1))
|
||||
image = np.expand_dims(image, axis=0).astype(np.float32)
|
||||
|
||||
return image
|
||||
|
||||
def _softmax(self, x: np.ndarray) -> np.ndarray:
|
||||
"""Apply softmax along axis 1."""
|
||||
e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
|
||||
return e_x / e_x.sum(axis=1, keepdims=True)
|
||||
|
||||
def postprocess(self, outputs: Tuple[np.ndarray, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Postprocess raw model outputs into gaze angles.
|
||||
|
||||
This method takes the raw output from the model's inference and
|
||||
converts it into pitch and yaw angles in radians.
|
||||
|
||||
Args:
|
||||
outputs: Raw outputs from the model inference. The format depends
|
||||
on the specific model architecture.
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, np.ndarray]: A tuple of (pitch, yaw) angles in radians.
|
||||
"""
|
||||
pitch_logits, yaw_logits = outputs
|
||||
|
||||
# Convert logits to probabilities
|
||||
pitch_probs = self._softmax(pitch_logits)
|
||||
yaw_probs = self._softmax(yaw_logits)
|
||||
|
||||
# Compute expected bin index (soft-argmax)
|
||||
pitch_deg = np.sum(pitch_probs * self._idx_tensor, axis=1) * self._binwidth - self._angle_offset
|
||||
yaw_deg = np.sum(yaw_probs * self._idx_tensor, axis=1) * self._binwidth - self._angle_offset
|
||||
|
||||
# Convert degrees to radians
|
||||
pitch = np.radians(pitch_deg[0])
|
||||
yaw = np.radians(yaw_deg[0])
|
||||
|
||||
return pitch, yaw
|
||||
|
||||
def estimate(self, face_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Perform end-to-end gaze estimation on a face image.
|
||||
|
||||
This method orchestrates the full pipeline: preprocessing the input,
|
||||
running inference, and postprocessing to return the gaze direction.
|
||||
"""
|
||||
input_tensor = self.preprocess(face_image)
|
||||
outputs = self.session.run(self.output_names, {self.input_name: input_tensor})
|
||||
pitch, yaw = self.postprocess((outputs[0], outputs[1]))
|
||||
|
||||
return pitch, yaw
|
||||
@@ -126,3 +126,97 @@ def draw_fancy_bbox(
|
||||
# Bottom-right corner
|
||||
cv2.line(image, (x2, y2), (x2, y2 - corner_length), color, thickness)
|
||||
cv2.line(image, (x2, y2), (x2 - corner_length, y2), color, thickness)
|
||||
|
||||
|
||||
def draw_gaze(
|
||||
image: np.ndarray,
|
||||
bbox: np.ndarray,
|
||||
pitch: np.ndarray,
|
||||
yaw: np.ndarray,
|
||||
*,
|
||||
draw_bbox: bool = True,
|
||||
fancy_bbox: bool = True,
|
||||
draw_angles: bool = True,
|
||||
):
|
||||
"""
|
||||
Draws gaze direction with optional bounding box on an image.
|
||||
|
||||
Args:
|
||||
image: Input image to draw on (modified in-place).
|
||||
bbox: Face bounding box [x1, y1, x2, y2].
|
||||
pitch: Vertical gaze angle in radians.
|
||||
yaw: Horizontal gaze angle in radians.
|
||||
draw_bbox: Whether to draw the bounding box. Defaults to True.
|
||||
fancy_bbox: Use fancy corner-style bbox. Defaults to True.
|
||||
draw_angles: Whether to display pitch/yaw values as text. Defaults to False.
|
||||
"""
|
||||
x_min, y_min, x_max, y_max = map(int, bbox[:4])
|
||||
|
||||
# Calculate dynamic line thickness based on image size (same as draw_detections)
|
||||
line_thickness = max(round(sum(image.shape[:2]) / 2 * 0.003), 2)
|
||||
|
||||
# Calculate dynamic font scale based on bbox height (same as draw_detections)
|
||||
bbox_h = y_max - y_min
|
||||
font_scale = max(0.4, min(0.7, bbox_h / 200))
|
||||
font_thickness = 2
|
||||
|
||||
# Draw bounding box if requested
|
||||
if draw_bbox:
|
||||
if fancy_bbox:
|
||||
draw_fancy_bbox(image, bbox, color=(0, 255, 0), thickness=line_thickness)
|
||||
else:
|
||||
cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0, 255, 0), line_thickness)
|
||||
|
||||
# Calculate center of the bounding box
|
||||
x_center = (x_min + x_max) // 2
|
||||
y_center = (y_min + y_max) // 2
|
||||
|
||||
# Calculate the direction of the gaze
|
||||
length = x_max - x_min
|
||||
dx = int(-length * np.sin(pitch) * np.cos(yaw))
|
||||
dy = int(-length * np.sin(yaw))
|
||||
|
||||
point1 = (x_center, y_center)
|
||||
point2 = (x_center + dx, y_center + dy)
|
||||
|
||||
# Calculate dynamic center point radius based on line thickness
|
||||
center_radius = max(line_thickness + 1, 4)
|
||||
|
||||
# Draw gaze direction
|
||||
cv2.circle(image, (x_center, y_center), radius=center_radius, color=(0, 0, 255), thickness=-1)
|
||||
cv2.arrowedLine(
|
||||
image,
|
||||
point1,
|
||||
point2,
|
||||
color=(0, 0, 255),
|
||||
thickness=line_thickness,
|
||||
line_type=cv2.LINE_AA,
|
||||
tipLength=0.25,
|
||||
)
|
||||
|
||||
# Draw angle values
|
||||
if draw_angles:
|
||||
text = f'P:{np.degrees(pitch):.0f}deg Y:{np.degrees(yaw):.0f}deg'
|
||||
(text_width, text_height), baseline = cv2.getTextSize(
|
||||
text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness
|
||||
)
|
||||
|
||||
# Draw background rectangle for text
|
||||
cv2.rectangle(
|
||||
image,
|
||||
(x_min, y_min - text_height - baseline - 10),
|
||||
(x_min + text_width + 10, y_min),
|
||||
(0, 0, 255),
|
||||
-1,
|
||||
)
|
||||
|
||||
# Draw text
|
||||
cv2.putText(
|
||||
image,
|
||||
text,
|
||||
(x_min + 5, y_min - 5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
font_scale,
|
||||
(255, 255, 255),
|
||||
font_thickness,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user