Files
uniface/examples/08_gaze_estimation.ipynb

271 lines
1.2 MiB
Plaintext
Raw Permalink Normal View History

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Gaze Estimation with UniFace\n",
"\n",
"This notebook demonstrates gaze estimation using the **UniFace** library.\n",
"\n",
"## 1. Install UniFace"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install -q uniface==2.0.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Import Libraries"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"UniFace version: 2.0.0\n"
]
}
],
"source": [
"import cv2\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"\n",
"import uniface\n",
"from uniface.detection import RetinaFace\n",
"from uniface.gaze import MobileGaze\n",
"from uniface.visualization import draw_gaze\n",
"\n",
"print(f\"UniFace version: {uniface.__version__}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Initialize Models"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✓ Model loaded (CoreML (Apple Silicon))\n",
"✓ Model loaded (CoreML (Apple Silicon))\n"
]
}
],
"source": [
"# Initialize the face detector; detections scoring below 0.5 are discarded.\n",
"detector = RetinaFace(confidence_threshold=0.5)\n",
"\n",
"# Initialize the gaze estimator with its default weights.\n",
"# NOTE(review): a previous comment claimed a ResNet34 default backbone --\n",
"# confirm against the uniface.gaze.MobileGaze documentation before relying on it.\n",
"gaze_estimator = MobileGaze()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Process All Test Images\n",
"\n",
"Display original images in the first row and gaze-annotated images in the second row."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing: image0.jpg\n",
" Detected 1 face(s)\n",
" Face 1: pitch=-0.0°, yaw=7.1°\n",
"Processing: image1.jpg\n",
" Detected 1 face(s)\n",
" Face 1: pitch=-3.3°, yaw=-5.6°\n",
"Processing: image2.jpg\n",
" Detected 1 face(s)\n",
" Face 1: pitch=-3.9°, yaw=-0.3°\n",
"Processing: image3.jpg\n",
" Detected 1 face(s)\n",
" Face 1: pitch=-22.1°, yaw=1.0°\n",
"Processing: image4.jpg\n",
" Detected 1 face(s)\n",
" Face 1: pitch=2.1°, yaw=5.0°\n",
"\n",
"Processed 5 images\n"
]
}
],
"source": [
"# Collect all test images (sorted for a deterministic processing order)\n",
"test_images_dir = Path('../assets/test_images')\n",
"test_images = sorted(test_images_dir.glob('*.jpg'))\n",
"if not test_images:\n",
"    print(f'No .jpg images found in {test_images_dir.resolve()}')\n",
"\n",
"# Originals and gaze-annotated copies, kept in parallel for the plot below\n",
"original_images = []\n",
"processed_images = []\n",
"\n",
"for image_path in test_images:\n",
"    print(f\"Processing: {image_path.name}\")\n",
"\n",
"    # Load image; cv2.imread returns None (no exception) for unreadable files\n",
"    image = cv2.imread(str(image_path))\n",
"    if image is None:\n",
"        print(f'  Skipping unreadable image: {image_path.name}')\n",
"        continue\n",
"    original = image.copy()\n",
"\n",
"    # Detect faces\n",
"    faces = detector.detect(image)\n",
"    print(f'  Detected {len(faces)} face(s)')\n",
"\n",
"    # Estimate gaze for each detected face\n",
"    for i, face in enumerate(faces):\n",
"        x1, y1, x2, y2 = map(int, face.bbox[:4])\n",
"        # Clamp to the image: negative coordinates would wrap in numpy slicing\n",
"        x1, y1 = max(x1, 0), max(y1, 0)\n",
"        face_crop = image[y1:y2, x1:x2]\n",
"\n",
"        # Degenerate boxes yield empty crops; skip them\n",
"        if face_crop.size > 0:\n",
"            pitch, yaw = gaze_estimator.estimate(face_crop)\n",
"            pitch_deg = np.degrees(pitch)\n",
"            yaw_deg = np.degrees(yaw)\n",
"\n",
"            print(f'  Face {i+1}: pitch={pitch_deg:.1f}°, yaw={yaw_deg:.1f}°')\n",
"\n",
"            # Draw gaze arrow on the working copy (angle text suppressed)\n",
"            draw_gaze(image, face.bbox, pitch, yaw, draw_angles=False)\n",
"\n",
"    # OpenCV loads BGR; matplotlib expects RGB\n",
"    original_rgb = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)\n",
"    processed_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
"\n",
"    original_images.append(original_rgb)\n",
"    processed_images.append(processed_rgb)\n",
"\n",
"print(f\"\\nProcessed {len(processed_images)} images\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Visualize Results\n",
"\n",
"**First row**: Original images \n",
"**Second row**: Images with gaze direction arrows"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB7QAAAMcCAYAAADQUZqvAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzsvQmwrdlZl/+dc8ce0p0QMpM0GTsJSWcEkhhAFDSCMihYCoJSlDhUaSElCKKlWOBfgxOFRYkTgwIORIKKU0AQQoWQkQ7pJJ2ZEObMPd3hnPuvZ+397P07713fPvtOfW+n19u9797nG9b4rnde79o5d+7cuWnAgAEDBgwYMGDAgAEDBgwYMGDAgAEDBgwYMGDAgAEDBgy4xmD3ajdgwIABAwYMGDBgwIABAwYMGDBgwIABAwYMGDBgwIABAwYM6MFwaA8YMGDAgAEDBgwYMGDAgAEDBgwYMGDAgAEDBgwYMGDAgGsShkN7wIABAwYMGDBgwIABAwYMGDBgwIABAwYMGDBgwIABAwZckzAc2gMGDBgwYMCAAQMGDBgwYMCAAQMGDBgwYMCAAQMGDBgw4JqE4dAeMGDAgAEDBgwYMGDAgAEDBgwYMGDAgAEDBgwYMGDAgAHXJAyH9oABAwYMGDBgwIABAwYMGDBgwIABAwYMGDBgwIABAwYMuCZhOLQHDBgwYMCAAQMGDBgwYMCAAQMGDBgwYMCAAQMGDBgwYMA1CcOhPWDAgAEDBgwYMGDAgAEDBgwYMGDAgAEDBgwYMGDAgAEDrkkYDu0BAwYMGDBgwIABAwYMGDBgwIABAwYMGDBgwIABAwYMGHBNwnBoD3jAwd/9u3932tnZuah3f/AHf7C9+773vW+6UkDZ1EFdAwYMGDDg6sLgGQMGDBgwYFsYPGPAgAEDBmwLg2cMGDBgwIBtYfCMAQMuDwyH9oD7Dd761rdOf+bP/JnpcY973HTixInpsY997PTVX/3V7fqDFfb396eXv/zl0xOf+MTp5MmT02233Tb92I/92NVu1oABAwZcdRg843z4ru/6rulLvuRLpkc96lFN0UAhGjBgwIABg2dUePvb3z59y7d8y/Tc5z53eshDHjI95jGPmb74i794ev3rX3+1mzZgwIABVx0GzzgIv/Ebv9HG49Zbb20846EPfej0WZ/1WdMP/dAPTefOnbvazRswYMCAqwqDZ2yGH/mRH2n2qRtvvPFqN2XAgwSGQ3vA/QL/5b/8l+n5z3/+9DM/8zPT133d103f933fN33913/99LM/+7Pt+k/8xE9sXdbf+lt/a7r33nsvqh1f8zVf09695ZZbpmsBvv3bv336G3/jb0xf+IVfOH3v937v9IQnPGH6qq/6quk//If/cLWbNmDAgAFXDQbPmO/L6173uul5z3ve1W7KgAEDBlwzMHjG+fCv//W/nv7Vv/pX0wtf+MLpH//jfzx90zd90/SOd7xjetGLXjT99E//9NVu3oABAwZcNRg843z4vd/7venXf/3Xp6/4iq+Y/tE/+kfTd37nd7ZAqD/35/5cs1kNGDBgwIMVBs/YDHfddVcLor3hhhuudlMGPIhg59wItxtwheHd735323mMs/bnf/7np0c84hEHBOfP+ZzPmT7wgQ9Mt99++/SkJz1ptpy77777AUEgSdHBjusf+IEfaArAHHzwgx9sz33DN3zD9M//+T9v11iOn/d5nze9973vbeUcOXLkfmz5gAEDBlx9GDxj87Of/umf3saBcfk7f+fvjF3aAwYMeFDD4Bl9eMMb3tB22uVOiQ996EPTM57xjOlpT3va9OpXv/p+avGAAQMGXDsweMaFwR/7Y3+sOW0+9rGPDdvUgAEDHnQweMbh8K3f+q3TK1/5yhZEyzcO7gEDrjSMHdoDrjh893d/93TPPfdM//Jf/ssDxB/41E/91On7v//7G3En9XY9V+KOO+5oO5Yf9rCHTS996UsP3EsgSumv/tW/2sojRR
IpWXEY15SsvTMncA780T/6R5thh7RKpP6GEf3wD//wgTo+/OEPT3/9r//16dnPfnYzDt10003TH/kjf2T6lV/5lYsal5/8yZ+czpw5M/3lv/yXV9do21/6S3+pRce+5jWvuahyBwwYMOCBDINnzAN1DxgwYMCANQye0YcXvOAF56X9e/jDH94Mb29729suqswBAwYMeKDD4BkXBrSH8Tp9+vRlLXfAgAEDHggweMZmeOc73zn903/6T6d/8k/+yXT06NFLKmvAgAuB4dAecMXhv/23/9aILAaUHnzu535uu/9TP/VT5937yq/8ysY8/v7f//vTn//zf362DiKHSNn9RV/0RdM//If/cLruuuvaOXHbwrve9a6WXonU36Tlg+FQZp6H8Z73vKdFG8EsINbf/M3fPL3lLW9pO6o5c+hC4U1velOL0GKnRAJMyPsDBgwY8GCDwTMGDBgwYMC2MHjGhcFv/dZvNYPZgAEDBjwYYfCMzYBjhV2HOEw4P5tdei9+8YtbHwYMGDDgwQaDZ2yGb/zGb5w+//M/v7V9wID7E0b4xIArCqQmgjh+6Zd+6cbnSOHxX//rf50+8YlPtIgk4TnPec70oz/6oxvffeMb3zj9p//0nxohJTIIYNczZ1tsG23EmXKkD5FJ/ck/+Senxz/+8U2A5wwhgEimO++8c9rd3T1whsXTn/706d/8m38z/e2//benC4Hf/M3fnB71qEedF53FWUXAcHgMGDDgwQaDZwwYMGDAgG1h8IwLg1/4hV9oGaA4v2/AgAEDHmwweMbh8D3f8z3Tt33bt63+/oN/8A+2egcMGDDgwQaDZ2wGnPj/5//8n8ueGWTAgG1g7NAecEUBgg4kUe+B9z/+8Y8fuP4X/+JfPLSO//W//lf7ztTdwF/5K39l63Y+85nPPBBxRSoRzp0jikk4ceLEivjv7e21c+hI1cFzMKGLiX6lzAqkCPH+gAEDBjyYYPCMAQMGDBiwLQyesT38zu/8Tkt7yLl43/It33LJ5Q0YMGDAAw0Gzzgc/vSf/tPTq171quaEgWcAwy41YMCAByMMnjEPHEPx1/7aX2t9pP4BA+5vGA7tAVcUJOwyggtlFBhdDoP3v//9jTDXZ5/ylKds3c4nPOEJ510jTcdHPvKR1d/7+/stYuqpT31qYwak64NR3H777S1y60KBNCKnTp067/p99923uj9gwIABDyYYPGPAgAEDBmwLg2dsB5ztR4pBxuEnf/Inzztbe8CAAQMeDDB4xuFwyy23TF/wBV/QHNs/8iM/0s5i5e/h1B4wYMCDDQbPmAfK4niK7/iO77jgdwcMuBwwHNoDrijcfPPNLYU2RHITcP9xj3vcdNNNNx24fn85dY8cOdK9fu7cudVvzr34pm/6pnZGxr//9/9++t//+3+36NXP+IzPaMzhQoFx4Ry7rMNU5MBjH/vYCy5zwIABAx7IMHjGgAEDBgzYFgbP2G4HxR//43+8jQHO7Gc961kXXdaAAQMGPJBh8IwLB85l/cAHPtDS2Q4YMGDAgwkGz+gDDvDv/M7vbOeCsyv9fe97X/vcddddrU5+kxlqwIArCcOhPeCKAzsC3vve906vfvWrZ89zg+Dx3MVGkUKAqSPhXe9613Q54cd//Menz//8z2/nS/ypP/Wnpj/0h/5Qi1b96Ec/elHlPfe5z53uueee6W1ve9uB66997WtX9wcMGDDgwQaDZwwYMGDAgG1h8Ix5oN1f+7VfO/3Mz/xMSx/7eZ/3eZe1zQMGDBjwQIPBMy4M3Jk9sksNGDDgwQiDZ5wP7PzGef3yl7+87Sz384pXvKL5OPj9Dd/wDZe1/QMGVBgO7QFXHL75m7+5RSb9hb/wF9o5DQkf/vCH25kL119/fXvuYuAP/+E/3L6/7/u+78D17/3e750ud9RT3U39n//zf54++MEPXlR5X/qlXzodO3bsQLsp/1/8i3/Rorte8pKXXHKbBwwYMOCBBoNnDBgwYMCAbWHwjH
ng/L3/+B//Y2s7u7QHDBgw4MEOg2f04Xd/93e713F+7OzsTM9//vMvqtwBAwYMeCDD4BnnwyMf+cjpJ37iJ8774DA
"text/plain": [
"<Figure size 2000x800 with 10 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"num_images = len(original_images)\n",
"\n",
"if num_images == 0:\n",
"    # plt.subplots(2, 0) raises, so bail out explicitly when nothing was processed\n",
"    print('Nothing to plot: no images were processed.')\n",
"else:\n",
"    # squeeze=False keeps `axes` 2-D even for a single column,\n",
"    # removing the need for a special-case reshape\n",
"    fig, axes = plt.subplots(2, num_images, figsize=(4 * num_images, 8), squeeze=False)\n",
"\n",
"    # First row: originals; second row: gaze-annotated copies\n",
"    for i, (orig_img, proc_img) in enumerate(zip(original_images, processed_images)):\n",
"        axes[0, i].imshow(orig_img)\n",
"        axes[0, i].set_title(f'Original {i}', fontsize=12)\n",
"        axes[0, i].axis('off')\n",
"\n",
"        axes[1, i].imshow(proc_img)\n",
"        axes[1, i].set_title(f'Gaze Estimation {i}', fontsize=12)\n",
"        axes[1, i].axis('off')\n",
"\n",
"    plt.tight_layout()\n",
"    plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Notes\n",
"\n",
"- **Input**: Gaze estimation requires a face crop (obtained from face detection)\n",
"- **Output**: Returns (pitch, yaw) angles in radians\n",
"- **Visualization**: `draw_gaze()` automatically draws bounding box and gaze arrow\n",
"- **Models**: Trained on Gaze360 dataset with diverse head poses\n",
"- **Performance**: MAE (Mean Absolute Error) ranges from 11 to 13 degrees\n",
"\n",
"### Tips for Best Results\n",
"- Ensure faces are clearly visible and well-lit\n",
"- Works best with frontal to semi-profile faces\n",
"- Accuracy may vary with extreme head poses or occlusions"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}