diff --git a/MODELS.md b/MODELS.md
index 4fe4f96..64cbaa2 100644
--- a/MODELS.md
+++ b/MODELS.md
@@ -404,6 +404,47 @@ print(f"Detected {len(np.unique(mask))} facial components")
 
 ---
 
+## Anti-Spoofing Models
+
+### MiniFASNet Family
+
+Lightweight face anti-spoofing models for liveness detection. They classify a detected face as real (live) or fake (photo, video replay, mask).
+
+| Model Name | Size   | Scale | Use Case                       |
+| ---------- | ------ | ----- | ------------------------------ |
+| `V1SE`     | 1.2 MB | 4.0   | Squeeze-and-excitation variant |
+| `V2` ⭐    | 1.2 MB | 2.7   | **Recommended default**        |
+
+**Dataset**: Trained on face anti-spoofing datasets (weights from [minivision-ai/Silent-Face-Anti-Spoofing](https://github.com/minivision-ai/Silent-Face-Anti-Spoofing))
+**Output**: Returns `(label_idx, score)`, where `label_idx` is 0 = Fake, 1 = Real
+
+#### Usage
+
+```python
+import cv2
+
+from uniface import RetinaFace
+from uniface.spoofing import MiniFASNet
+from uniface.constants import MiniFASNetWeights
+
+# Default (V2, recommended)
+detector = RetinaFace()
+spoofer = MiniFASNet()
+
+# V1SE variant
+spoofer = MiniFASNet(model_name=MiniFASNetWeights.V1SE)
+
+# Detect and check liveness
+image = cv2.imread("photo.jpg")
+faces = detector.detect(image)
+for face in faces:
+    label_idx, score = spoofer.predict(image, face['bbox'])
+    # label_idx: 0 = Fake, 1 = Real
+    label = 'Real' if label_idx == 1 else 'Fake'
+    print(f"{label}: {score:.1%}")
+```
+
+**Note**: Requires a face bounding box from a detector. Use with RetinaFace, SCRFD, or YOLOv5Face.
+
+---
+
 ## Model Updates
 
 Models are automatically downloaded and cached on first use. Cache location: `~/.uniface/models/`
 
@@ -445,6 +486,7 @@ python scripts/download_model.py --model MNET_V2
 
 - **Face Recognition Training**: [yakhyo/face-recognition](https://github.com/yakhyo/face-recognition) - ArcFace, MobileFace, SphereFace training code
 - **Gaze Estimation Training**: [yakhyo/gaze-estimation](https://github.com/yakhyo/gaze-estimation) - MobileGaze training code and pretrained weights
 - **Face Parsing Training**: [yakhyo/face-parsing](https://github.com/yakhyo/face-parsing) - BiSeNet training code and pretrained weights
+- **Face Anti-Spoofing**: [yakhyo/face-anti-spoofing](https://github.com/yakhyo/face-anti-spoofing) - MiniFASNet ONNX inference (weights from [minivision-ai/Silent-Face-Anti-Spoofing](https://github.com/minivision-ai/Silent-Face-Anti-Spoofing))
 - **InsightFace**: [deepinsight/insightface](https://github.com/deepinsight/insightface) - Model architectures and pretrained weights
 
 ### Papers
 
diff --git a/QUICKSTART.md b/QUICKSTART.md
index ba01448..eb76c8b 100644
--- a/QUICKSTART.md
+++ b/QUICKSTART.md
@@ -420,7 +420,46 @@ python scripts/run_anonymization.py --image photo.jpg --method gaussian --blur-s
 
 ---
 
-## 10. Batch Processing (3 minutes)
+## 10. Face Anti-Spoofing (2 minutes)
+
+Detect whether a face is real or fake (photo, video replay, mask):
+
+```python
+import cv2
+
+from uniface import RetinaFace
+from uniface.spoofing import MiniFASNet
+
+detector = RetinaFace()
+spoofer = MiniFASNet()  # Uses V2 by default
+
+image = cv2.imread("photo.jpg")
+faces = detector.detect(image)
+
+for i, face in enumerate(faces):
+    label_idx, score = spoofer.predict(image, face['bbox'])
+    # label_idx: 0 = Fake, 1 = Real
+    label = 'Real' if label_idx == 1 else 'Fake'
+    print(f"Face {i+1}: {label} ({score:.1%})")
+```
+
+**Output:**
+
+```
+Face 1: Real (98.5%)
+```
+
+**Command-line tool:**
+
+```bash
+# Image
+python scripts/run_spoofing.py --image photo.jpg
+
+# Webcam
+python scripts/run_spoofing.py --source 0
+```
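+
+**Tip:** In an access-control flow, act only on confident "Real" predictions. A minimal sketch that builds on the snippet above (the `0.90` threshold is an illustrative value, not a library default; tune it on your own data):
+
+```python
+REAL_THRESHOLD = 0.90  # hypothetical cut-off, not part of uniface
+
+for face in faces:
+    label_idx, score = spoofer.predict(image, face['bbox'])
+    if label_idx == 1 and score >= REAL_THRESHOLD:
+        print('Live face: proceed')
+    else:
+        print('Possible spoof: reject')
+```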
+
+---
+
+## 11. Batch Processing (3 minutes)
 
 Process multiple images:
 
@@ -453,7 +492,7 @@ print("Done!")
 
 ---
 
-## 11. Model Selection
+## 12. Model Selection
 
 Choose the right model for your use case:
 
diff --git a/README.md b/README.md
index f1b1e5a..6d9e46d 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,7 @@
 - **Face Parsing**: BiSeNet-based semantic segmentation with 19 facial component classes
 - **Gaze Estimation**: Real-time gaze direction prediction with MobileGaze
 - **Attribute Analysis**: Age, gender, and emotion detection
+- **Anti-Spoofing**: Face liveness detection with MiniFASNet models
 - **Face Anonymization**: Privacy-preserving face blurring with 5 methods (pixelate, gaussian, blackout, elliptical, median)
 - **Face Alignment**: Precise alignment for downstream tasks
 - **Hardware Acceleration**: ARM64 optimizations (Apple Silicon), CUDA (NVIDIA), CPU fallback
@@ -199,6 +200,25 @@ vis_result = vis_parsing_maps(face_rgb, mask, save_image=False)
 print(f"Unique classes: {len(np.unique(mask))}")
 ```
 
+### Face Anti-Spoofing
+
+Detect whether a face is real or fake (photo, video replay, mask):
+
+```python
+import cv2
+
+from uniface import RetinaFace
+from uniface.spoofing import MiniFASNet
+
+detector = RetinaFace()
+spoofer = MiniFASNet()  # Uses V2 by default
+
+image = cv2.imread("photo.jpg")
+faces = detector.detect(image)
+for face in faces:
+    label_idx, score = spoofer.predict(image, face['bbox'])
+    # label_idx: 0 = Fake, 1 = Real
+    label = 'Real' if label_idx == 1 else 'Fake'
+    print(f"{label}: {score:.1%}")
+```
+
 ### Face Anonymization
 
 Protect privacy by blurring or pixelating faces with 5 different methods:
 
@@ -364,6 +384,12 @@ faces = detect_faces(image, method='retinaface', conf_thresh=0.8)  # methods: re
 | ---------- | ---------------------------------------- | ------------------------------------ |
 | `BiSeNet`  | `model_name=ParsingWeights.RESNET18`, `input_size=(512, 512)` | 19 facial component classes; BiSeNet architecture with ResNet backbone |
 
+**Anti-Spoofing**
+
+| Class        | Key params (defaults)             | Notes                                      |
+| ------------ | --------------------------------- | ------------------------------------------ |
+| `MiniFASNet` | `model_name=MiniFASNetWeights.V2` | Returns (label_idx, score); 0=Fake, 1=Real |
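+
+The `create_spoofer` factory is the equivalent functional entry point. A short sketch (the `scale=3.0` override is illustrative; by default each variant uses its own crop scale):
+
+```python
+from uniface import create_spoofer
+from uniface.constants import MiniFASNetWeights
+
+# Same as MiniFASNet(model_name=..., scale=...)
+spoofer = create_spoofer(model_name=MiniFASNetWeights.V1SE, scale=3.0)
+```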
"execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { diff --git a/scripts/run_anonymization.py b/scripts/run_anonymization.py index 1e2e00c..3c8d432 100644 --- a/scripts/run_anonymization.py +++ b/scripts/run_anonymization.py @@ -205,3 +205,4 @@ Examples: if __name__ == '__main__': main() + diff --git a/scripts/run_spoofing.py b/scripts/run_spoofing.py new file mode 100644 index 0000000..e30bb9b --- /dev/null +++ b/scripts/run_spoofing.py @@ -0,0 +1,201 @@ +# Face Anti-Spoofing Detection +# Usage: +# Image: python run_spoofing.py --image path/to/image.jpg +# Video: python run_spoofing.py --video path/to/video.mp4 +# Webcam: python run_spoofing.py --source 0 + +import argparse +import os +from pathlib import Path + +import cv2 +import numpy as np + +from uniface import RetinaFace +from uniface.constants import MiniFASNetWeights +from uniface.spoofing import create_spoofer + + +def draw_spoofing_result( + image: np.ndarray, + bbox: list, + label_idx: int, + score: float, + thickness: int = 2, +) -> None: + """Draw bounding box with anti-spoofing result. + + Args: + image: Input image to draw on. + bbox: Bounding box in [x1, y1, x2, y2] format. + label_idx: Prediction label index (0 = Fake, 1 = Real). + score: Confidence score (0.0 to 1.0). + thickness: Line thickness for bounding box. + """ + x1, y1, x2, y2 = map(int, bbox[:4]) + + # Color based on result (green for real, red for fake) + is_real = label_idx == 1 + color = (0, 255, 0) if is_real else (0, 0, 255) + + # Draw bounding box + cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness) + + # Prepare label + label = 'Real' if is_real else 'Fake' + text = f'{label}: {score:.1%}' + + # Draw label background + (tw, th), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2) + cv2.rectangle(image, (x1, y1 - th - 10), (x1 + tw + 10, y1), color, -1) + + # Draw label text + cv2.putText(image, text, (x1 + 5, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) + + +def process_image(detector, spoofer, image_path: str, save_dir: str = 'outputs') -> None: + """Process a single image for face anti-spoofing detection.""" + image = cv2.imread(image_path) + if image is None: + print(f"Error: Failed to load image from '{image_path}'") + return + + # Detect faces + faces = detector.detect(image) + print(f'Detected {len(faces)} face(s)') + + if not faces: + print('No faces detected in the image.') + return + + # Run anti-spoofing on each face + for i, face in enumerate(faces, 1): + label_idx, score = spoofer.predict(image, face['bbox']) + # label_idx: 0 = Fake, 1 = Real + label = 'Real' if label_idx == 1 else 'Fake' + print(f' Face {i}: {label} ({score:.1%})') + + # Draw result on image + draw_spoofing_result(image, face['bbox'], label_idx, score) + + # Save output + os.makedirs(save_dir, exist_ok=True) + output_path = os.path.join(save_dir, f'{Path(image_path).stem}_spoofing.jpg') + cv2.imwrite(output_path, image) + print(f'Output saved: {output_path}') + + +def process_video(detector, spoofer, source, save_dir: str = 'outputs') -> None: + """Process video or webcam stream for face anti-spoofing detection.""" + # Handle webcam or video file + if isinstance(source, int) or source.isdigit(): + cap = cv2.VideoCapture(int(source)) + is_webcam = True + output_name = 'webcam_spoofing.mp4' + else: + cap = cv2.VideoCapture(source) + is_webcam = False + output_name = f'{Path(source).stem}_spoofing.mp4' + + if not cap.isOpened(): + print(f'Error: Failed to open video source: {source}') + return + + # Get video 
+    # Get video properties
+    fps = int(cap.get(cv2.CAP_PROP_FPS)) if not is_webcam else 30
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+    # Setup video writer
+    os.makedirs(save_dir, exist_ok=True)
+    output_path = os.path.join(save_dir, output_name)
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+    print("Processing video... Press 'q' to quit")
+    frame_count = 0
+
+    try:
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            frame_count += 1
+
+            # Detect faces
+            faces = detector.detect(frame)
+
+            # Run anti-spoofing on each face
+            for face in faces:
+                label_idx, score = spoofer.predict(frame, face['bbox'])
+                draw_spoofing_result(frame, face['bbox'], label_idx, score)
+
+            # Write frame
+            writer.write(frame)
+
+            # Display frame
+            cv2.imshow('Face Anti-Spoofing', frame)
+            if cv2.waitKey(1) & 0xFF == ord('q'):
+                print('Stopped by user.')
+                break
+
+    finally:
+        cap.release()
+        writer.release()
+        cv2.destroyAllWindows()
+
+    print(f'Processed {frame_count} frames')
+    if not is_webcam:
+        print(f'Output saved: {output_path}')
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Face Anti-Spoofing Detection')
+    parser.add_argument('--image', type=str, help='Path to input image')
+    parser.add_argument('--video', type=str, help='Path to input video')
+    parser.add_argument('--source', type=str, help='Video source (0 for webcam)')
+    parser.add_argument(
+        '--model',
+        type=str,
+        default='v2',
+        choices=['v1se', 'v2'],
+        help='Model variant: v1se or v2 (default: v2)',
+    )
+    parser.add_argument('--scale', type=float, default=None, help='Custom crop scale (default: auto)')
+    parser.add_argument('--save_dir', type=str, default='outputs', help='Output directory')
+    args = parser.parse_args()
+
+    # Check that at least one input source is provided
+    if not any([args.image, args.video, args.source]):
+        parser.print_help()
+        print('\nError: Please provide --image, --video, or --source')
+        return
+
+    # Select model variant
+    model_name = MiniFASNetWeights.V1SE if args.model == 'v1se' else MiniFASNetWeights.V2
+
+    # Initialize models
+    print(f'Initializing models (MiniFASNet {args.model.upper()})...')
+    detector = RetinaFace()
+    spoofer = create_spoofer(model_name=model_name, scale=args.scale)
+
+    # Process input
+    if args.image:
+        if not os.path.exists(args.image):
+            print(f'Error: Image not found: {args.image}')
+            return
+        process_image(detector, spoofer, args.image, args.save_dir)
+
+    elif args.video:
+        if not os.path.exists(args.video):
+            print(f'Error: Video not found: {args.video}')
+            return
+        process_video(detector, spoofer, args.video, args.save_dir)
+
+    elif args.source:
+        process_video(detector, spoofer, args.source, args.save_dir)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/uniface/__init__.py b/uniface/__init__.py
index 2c9d42f..dc8ea3d 100644
--- a/uniface/__init__.py
+++ b/uniface/__init__.py
@@ -42,6 +42,7 @@ from .landmark import Landmark106, create_landmarker
 from .parsing import BiSeNet, create_face_parser
 from .privacy import BlurFace, anonymize_faces
 from .recognition import ArcFace, MobileFace, SphereFace, create_recognizer
+from .spoofing import MiniFASNet, create_spoofer
 
 __all__ = [
     '__author__',
@@ -56,6 +57,7 @@ __all__ = [
     'create_gaze_estimator',
     'create_landmarker',
     'create_recognizer',
+    'create_spoofer',
     'detect_faces',
     'list_available_detectors',
     # Detection models
@@ -75,6 +77,8 @@ __all__ = [
     # Attribute models
     'AgeGender',
     'Emotion',
+    # Spoofing models
+    'MiniFASNet',
     # Privacy
     'BlurFace',
     'anonymize_faces',
diff --git a/uniface/constants.py b/uniface/constants.py
index 9752513..7ad08f6 100644
--- a/uniface/constants.py
+++ b/uniface/constants.py
@@ -119,6 +119,20 @@ class ParsingWeights(str, Enum):
     RESNET34 = "parsing_resnet34"
 
 
+class MiniFASNetWeights(str, Enum):
+    """
+    MiniFASNet: Lightweight Face Anti-Spoofing models.
+    Trained on face anti-spoofing datasets.
+    https://github.com/yakhyo/face-anti-spoofing
+
+    Model Variants:
+        - V1SE: Uses scale=4.0 for face crop (squeeze-and-excitation version)
+        - V2: Uses scale=2.7 for face crop (improved version)
+    """
+
+    V1SE = "minifasnet_v1se"
+    V2 = "minifasnet_v2"
+
+
 MODEL_URLS: Dict[Enum, str] = {
     # RetinaFace
     RetinaFaceWeights.MNET_025: 'https://github.com/yakhyo/uniface/releases/download/weights/retinaface_mv1_0.25.onnx',
@@ -161,6 +175,9 @@ MODEL_URLS: Dict[Enum, str] = {
     # Parsing
     ParsingWeights.RESNET18: 'https://github.com/yakhyo/face-parsing/releases/download/weights/resnet18.onnx',
     ParsingWeights.RESNET34: 'https://github.com/yakhyo/face-parsing/releases/download/weights/resnet34.onnx',
+    # Anti-Spoofing (MiniFASNet)
+    MiniFASNetWeights.V1SE: 'https://github.com/yakhyo/face-anti-spoofing/releases/download/weights/MiniFASNetV1SE.onnx',
+    MiniFASNetWeights.V2: 'https://github.com/yakhyo/face-anti-spoofing/releases/download/weights/MiniFASNetV2.onnx',
 }
 
 MODEL_SHA256: Dict[Enum, str] = {
@@ -205,6 +222,9 @@ MODEL_SHA256: Dict[Enum, str] = {
     # Face Parsing
     ParsingWeights.RESNET18: '0d9bd318e46987c3bdbfacae9e2c0f461cae1c6ac6ea6d43bbe541a91727e33f',
     ParsingWeights.RESNET34: '5b805bba7b5660ab7070b5a381dcf75e5b3e04199f1e9387232a77a00095102e',
+    # Anti-Spoofing (MiniFASNet)
+    MiniFASNetWeights.V1SE: 'ebab7f90c7833fbccd46d3a555410e78d969db5438e169b6524be444862b3676',
+    MiniFASNetWeights.V2: 'b32929adc2d9c34b9486f8c4c7bc97c1b69bc0ea9befefc380e4faae4e463907',
 }
 
 CHUNK_SIZE = 8192
 
diff --git a/uniface/spoofing/__init__.py b/uniface/spoofing/__init__.py
new file mode 100644
index 0000000..352140e
--- /dev/null
+++ b/uniface/spoofing/__init__.py
@@ -0,0 +1,64 @@
+# Copyright 2025 Yakhyokhuja Valikhujaev
+# Author: Yakhyokhuja Valikhujaev
+# GitHub: https://github.com/yakhyo
+
+from typing import Optional
+
+from uniface.constants import MiniFASNetWeights
+
+from .base import BaseSpoofer
+from .minifasnet import MiniFASNet
+
+__all__ = [
+    'BaseSpoofer',
+    'MiniFASNet',
+    'MiniFASNetWeights',
+    'create_spoofer',
+]
+
+
+def create_spoofer(
+    model_name: MiniFASNetWeights = MiniFASNetWeights.V2,
+    scale: Optional[float] = None,
+) -> MiniFASNet:
+    """
+    Factory function to create a face anti-spoofing model.
+
+    This is a convenience function that creates a MiniFASNet instance
+    with the specified model variant and optional custom scale.
+
+    Args:
+        model_name (MiniFASNetWeights): The model variant to use.
+            Options:
+                - MiniFASNetWeights.V2: Improved version (default), uses scale=2.7
+                - MiniFASNetWeights.V1SE: Squeeze-and-excitation version, uses scale=4.0
+            Defaults to MiniFASNetWeights.V2.
+        scale (Optional[float]): Custom crop scale factor for face region.
+            If None, uses the default scale for the selected model variant.
+
+    Returns:
+        MiniFASNet: An initialized face anti-spoofing model.
+
+    Example:
+        >>> from uniface.spoofing import create_spoofer, MiniFASNetWeights
+        >>> from uniface import RetinaFace
+        >>>
+        >>> # Create with default settings (V2 model)
+        >>> spoofer = create_spoofer()
+        >>>
+        >>> # Create with V1SE model
+        >>> spoofer = create_spoofer(model_name=MiniFASNetWeights.V1SE)
+        >>>
+        >>> # Create with custom scale
+        >>> spoofer = create_spoofer(scale=3.0)
+        >>>
+        >>> # Use with face detector
+        >>> detector = RetinaFace()
+        >>> faces = detector.detect(image)
+        >>> for face in faces:
+        ...     label_idx, score = spoofer.predict(image, face['bbox'])
+        ...     # label_idx: 0 = Fake, 1 = Real
+        ...     label = 'Real' if label_idx == 1 else 'Fake'
+        ...     print(f'{label}: {score:.2%}')
+    """
+    return MiniFASNet(model_name=model_name, scale=scale)
diff --git a/uniface/spoofing/base.py b/uniface/spoofing/base.py
new file mode 100644
index 0000000..b16decd
--- /dev/null
+++ b/uniface/spoofing/base.py
@@ -0,0 +1,117 @@
+# Copyright 2025 Yakhyokhuja Valikhujaev
+# Author: Yakhyokhuja Valikhujaev
+# GitHub: https://github.com/yakhyo
+
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Union
+
+import numpy as np
+
+
+class BaseSpoofer(ABC):
+    """
+    Abstract base class for all face anti-spoofing models.
+
+    This class defines the common interface that all anti-spoofing models must implement,
+    ensuring consistency across different spoofing detection methods. Anti-spoofing models
+    detect whether a face is real (live person) or fake (photo, video, mask, etc.).
+
+    The prediction returns a tuple of (label_idx, score):
+        - label_idx: 0 = Fake (spoof), 1 = Real (live)
+        - score: Confidence score for the predicted label (0.0 to 1.0)
+    """
+
+    @abstractmethod
+    def _initialize_model(self) -> None:
+        """
+        Initialize the underlying model for inference.
+
+        This method should handle loading model weights, creating the
+        inference session (e.g., ONNX Runtime), and any necessary
+        setup procedures to prepare the model for prediction.
+
+        Raises:
+            RuntimeError: If the model fails to load or initialize.
+        """
+        raise NotImplementedError('Subclasses must implement the _initialize_model method.')
+
+    @abstractmethod
+    def preprocess(self, image: np.ndarray, bbox: Union[List, np.ndarray]) -> np.ndarray:
+        """
+        Preprocess the input image for model inference.
+
+        This method should crop the face region using the bounding box,
+        resize it to the model's expected input size, and normalize
+        the pixel values as required by the model.
+
+        Args:
+            image (np.ndarray): Input image in BGR format with shape (H, W, C).
+            bbox (Union[List, np.ndarray]): Face bounding box in [x1, y1, x2, y2] format.
+
+        Returns:
+            np.ndarray: The preprocessed image tensor ready for inference,
+                typically with shape (1, C, H, W).
+        """
+        raise NotImplementedError('Subclasses must implement the preprocess method.')
+
+    @abstractmethod
+    def postprocess(self, outputs: np.ndarray) -> Tuple[int, float]:
+        """
+        Postprocess raw model outputs into prediction result.
+
+        This method takes the raw output from the model's inference and
+        converts it into a label index and confidence score.
+
+        Args:
+            outputs (np.ndarray): Raw outputs from the model inference (logits).
+
+        Returns:
+            Tuple[int, float]: A tuple of (label_idx, score) where:
+                - label_idx: 0 = Fake (spoof), 1 = Real (live)
+                - score: Confidence score for the predicted label (0.0 to 1.0)
+        """
+        raise NotImplementedError('Subclasses must implement the postprocess method.')
+
+    @abstractmethod
+    def predict(self, image: np.ndarray, bbox: Union[List, np.ndarray]) -> Tuple[int, float]:
+        """
+        Perform end-to-end anti-spoofing prediction on a face.
+
+        This method orchestrates the full pipeline: preprocessing the input,
+        running inference, and postprocessing to return the prediction.
+
+        Args:
+            image (np.ndarray): Input image in BGR format containing the face.
+            bbox (Union[List, np.ndarray]): Face bounding box in [x1, y1, x2, y2] format.
+                This is typically obtained from a face detector.
+
+        Returns:
+            Tuple[int, float]: A tuple of (label_idx, score) where:
+                - label_idx: 0 = Fake (spoof), 1 = Real (live)
+                - score: Confidence score for the predicted label (0.0 to 1.0)
+
+        Example:
+            >>> spoofer = MiniFASNet()
+            >>> detector = RetinaFace()
+            >>> faces = detector.detect(image)
+            >>> for face in faces:
+            ...     label_idx, score = spoofer.predict(image, face['bbox'])
+            ...     label = 'Real' if label_idx == 1 else 'Fake'
+            ...     print(f'{label}: {score:.2%}')
+        """
+        raise NotImplementedError('Subclasses must implement the predict method.')
+
+    def __call__(self, image: np.ndarray, bbox: Union[List, np.ndarray]) -> Tuple[int, float]:
+        """
+        Provides a convenient, callable shortcut for the `predict` method.
+
+        Args:
+            image (np.ndarray): Input image in BGR format.
+            bbox (Union[List, np.ndarray]): Face bounding box in [x1, y1, x2, y2] format.
+
+        Returns:
+            Tuple[int, float]: A tuple of (label_idx, score) where:
+                - label_idx: 0 = Fake (spoof), 1 = Real (live)
+                - score: Confidence score for the predicted label (0.0 to 1.0)
+        """
+        return self.predict(image, bbox)
diff --git a/uniface/spoofing/minifasnet.py b/uniface/spoofing/minifasnet.py
new file mode 100644
index 0000000..0a43885
--- /dev/null
+++ b/uniface/spoofing/minifasnet.py
@@ -0,0 +1,225 @@
+# Copyright 2025 Yakhyokhuja Valikhujaev
+# Author: Yakhyokhuja Valikhujaev
+# GitHub: https://github.com/yakhyo
+
+from typing import List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+
+from uniface.constants import MiniFASNetWeights
+from uniface.log import Logger
+from uniface.model_store import verify_model_weights
+from uniface.onnx_utils import create_onnx_session
+
+from .base import BaseSpoofer
+
+__all__ = ['MiniFASNet']
+
+# Default crop scales for each model variant
+DEFAULT_SCALES = {
+    MiniFASNetWeights.V1SE: 4.0,
+    MiniFASNetWeights.V2: 2.7,
+}
+
+
+class MiniFASNet(BaseSpoofer):
+    """
+    MiniFASNet: Lightweight Face Anti-Spoofing with ONNX Runtime.
+
+    MiniFASNet is a face anti-spoofing model that detects whether a face is real
+    (live person) or fake (photo, video replay, mask, etc.). It supports two model
+    variants: V1SE (with squeeze-and-excitation) and V2 (improved version).
+
+    The model takes a face region cropped from the image using a bounding box
+    and predicts whether it's a real or spoofed face.
+
+    Reference:
+        https://github.com/yakhyo/face-anti-spoofing
+
+    Args:
+        model_name (MiniFASNetWeights): The enum specifying the model variant to load.
+            Options: V1SE (scale=4.0), V2 (scale=2.7).
+            Defaults to `MiniFASNetWeights.V2`.
+        scale (Optional[float]): Custom crop scale factor for face region.
+            If None, uses the default scale for the selected model variant.
+            V1SE uses 4.0, V2 uses 2.7.
+
+    Attributes:
+        scale (float): Crop scale factor for face region extraction.
+        input_size (Tuple[int, int]): Model input dimensions (width, height).
+
+    Example:
+        >>> from uniface.spoofing import MiniFASNet
+        >>> from uniface import RetinaFace
+        >>>
+        >>> detector = RetinaFace()
+        >>> spoofer = MiniFASNet()
+        >>>
+        >>> # Detect faces and check if they are real
+        >>> faces = detector.detect(image)
+        >>> for face in faces:
+        ...     label_idx, score = spoofer.predict(image, face['bbox'])
+        ...     # label_idx: 0 = Fake, 1 = Real
+        ...     label = 'Real' if label_idx == 1 else 'Fake'
+        ...     print(f'{label}: {score:.2%}')
+    """
+
+    def __init__(
+        self,
+        model_name: MiniFASNetWeights = MiniFASNetWeights.V2,
+        scale: Optional[float] = None,
+    ) -> None:
+        Logger.info(f'Initializing MiniFASNet with model={model_name.name}')
+
+        # Use default scale for the model variant if not specified
+        self.scale = scale if scale is not None else DEFAULT_SCALES.get(model_name, 2.7)
+
+        self.model_path = verify_model_weights(model_name)
+        self._initialize_model()
+
+    def _initialize_model(self) -> None:
+        """
+        Initialize the ONNX model from the stored model path.
+
+        Raises:
+            RuntimeError: If the model fails to load or initialize.
+        """
+        try:
+            self.session = create_onnx_session(self.model_path)
+
+            # Get input configuration
+            input_cfg = self.session.get_inputs()[0]
+            self.input_name = input_cfg.name
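+            # NOTE: This assumes a static NCHW input shape such as
+            # (1, 3, 80, 80); dynamic dimensions would appear as strings
+            # here and would need an explicit fallback.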
+            # Input shape is (batch, channels, height, width) - we need (width, height)
+            self.input_size = tuple(input_cfg.shape[2:4][::-1])  # (width, height)
+
+            # Get output configuration
+            output_cfg = self.session.get_outputs()[0]
+            self.output_name = output_cfg.name
+
+            Logger.info(f'MiniFASNet initialized with input size {self.input_size}, scale={self.scale}')
+
+        except Exception as e:
+            Logger.error(f"Failed to load MiniFASNet model from '{self.model_path}'", exc_info=True)
+            raise RuntimeError(f'Failed to initialize MiniFASNet model: {e}') from e
+
+    def _xyxy_to_xywh(self, bbox: Union[List, np.ndarray]) -> List[int]:
+        """Convert bounding box from [x1, y1, x2, y2] to [x, y, w, h] format."""
+        x1, y1, x2, y2 = bbox[:4]
+        return [int(x1), int(y1), int(x2 - x1), int(y2 - y1)]
+
+    def _crop_face(self, image: np.ndarray, bbox_xywh: List[int]) -> np.ndarray:
+        """
+        Crop and resize face region from image using scale factor.
+
+        The crop is centered on the face bounding box and scaled to capture
+        more context around the face, which is important for anti-spoofing.
+
+        Args:
+            image: Input image in BGR format.
+            bbox_xywh: Face bounding box in [x, y, w, h] format.
+
+        Returns:
+            Cropped and resized face region.
+        """
+        src_h, src_w = image.shape[:2]
+        x, y, box_w, box_h = bbox_xywh
+
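+        # Worked example (illustrative numbers): a 100x100 box centered at
+        # (250, 200) in a 640x480 image with scale=2.7 requests a 270x270
+        # crop around the same center; the min() below shrinks the request
+        # whenever it would not fit inside the image.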
+        # Calculate the scale to apply based on image and face size
+        scale = min((src_h - 1) / box_h, (src_w - 1) / box_w, self.scale)
+        new_w = box_w * scale
+        new_h = box_h * scale
+
+        # Calculate center of the bounding box
+        center_x = x + box_w / 2
+        center_y = y + box_h / 2
+
+        # Calculate new bounding box coordinates
+        x1 = max(0, int(center_x - new_w / 2))
+        y1 = max(0, int(center_y - new_h / 2))
+        x2 = min(src_w - 1, int(center_x + new_w / 2))
+        y2 = min(src_h - 1, int(center_y + new_h / 2))
+
+        # Crop and resize
+        cropped = image[y1 : y2 + 1, x1 : x2 + 1]
+        resized = cv2.resize(cropped, self.input_size)
+
+        return resized
+
+    def preprocess(self, image: np.ndarray, bbox: Union[List, np.ndarray]) -> np.ndarray:
+        """
+        Preprocess the input image for model inference.
+
+        Crops the face region, converts to float32, and arranges
+        dimensions for the model (NCHW format).
+
+        Args:
+            image: Input image in BGR format with shape (H, W, C).
+            bbox: Face bounding box in [x1, y1, x2, y2] format.
+
+        Returns:
+            Preprocessed image tensor with shape (1, C, H, W).
+        """
+        # Convert bbox format
+        bbox_xywh = self._xyxy_to_xywh(bbox)
+
+        # Crop and resize face region
+        face = self._crop_face(image, bbox_xywh)
+
+        # Convert to float32 (no normalization needed for this model)
+        face = face.astype(np.float32)
+
+        # HWC -> CHW -> NCHW
+        face = np.transpose(face, (2, 0, 1))
+        face = np.expand_dims(face, axis=0)
+
+        return face
+
+    def _softmax(self, x: np.ndarray) -> np.ndarray:
+        """Apply softmax to logits along axis 1."""
+        e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
+        return e_x / e_x.sum(axis=1, keepdims=True)
+
+    def postprocess(self, outputs: np.ndarray) -> Tuple[int, float]:
+        """
+        Postprocess raw model outputs into prediction result.
+
+        Applies softmax to convert logits to probabilities and
+        returns the predicted label index and confidence score.
+
+        Args:
+            outputs: Raw outputs from the model inference (logits).
+
+        Returns:
+            Tuple[int, float]: A tuple of (label_idx, score) where:
+                - label_idx: 0 = Fake (spoof), 1 = Real (live)
+                - score: Confidence score for the predicted label (0.0 to 1.0)
+        """
+        probs = self._softmax(outputs)
+        label_idx = int(np.argmax(probs))
+        score = float(probs[0, label_idx])
+
+        return label_idx, score
+
+    def predict(self, image: np.ndarray, bbox: Union[List, np.ndarray]) -> Tuple[int, float]:
+        """
+        Perform end-to-end anti-spoofing prediction on a face.
+
+        Args:
+            image: Input image in BGR format containing the face.
+            bbox: Face bounding box in [x1, y1, x2, y2] format.
+
+        Returns:
+            Tuple[int, float]: A tuple of (label_idx, score) where:
+                - label_idx: 0 = Fake (spoof), 1 = Real (live)
+                - score: Confidence score for the predicted label (0.0 to 1.0)
+        """
+        # Preprocess
+        input_tensor = self.preprocess(image, bbox)
+
+        # Run inference
+        outputs = self.session.run([self.output_name], {self.input_name: input_tensor})[0]
+
+        # Postprocess and return
+        return self.postprocess(outputs)