Files
uniface/scripts/download_model.py
yakhyo 2c78f39e5d ref: Add comprehensive test suite and enhance model functionality
- Add new test files for age_gender, factory, landmark, recognition, scrfd, and utils
- Add new scripts for age_gender, landmarks, and video detection
- Update documentation in README.md, MODELS.md, QUICKSTART.md
- Improve model constants and face utilities
- Update detection models (retinaface, scrfd) with enhanced functionality
- Update project configuration in pyproject.toml
2025-11-15 21:09:37 +09:00

78 lines
2.7 KiB
Python

import argparse
from uniface.constants import (
RetinaFaceWeights, SphereFaceWeights, MobileFaceWeights, ArcFaceWeights,
SCRFDWeights, DDAMFNWeights, AgeGenderWeights, LandmarkWeights
)
from uniface.model_store import verify_model_weights
# Registry of every downloadable model family: maps the name accepted by
# --model-type to the weights enum whose members are downloaded for it.
# main() iterates this mapping in insertion order when no type is specified.
ALL_MODEL_TYPES = dict(
    retinaface=RetinaFaceWeights,
    sphereface=SphereFaceWeights,
    mobileface=MobileFaceWeights,
    arcface=ArcFaceWeights,
    scrfd=SCRFDWeights,
    ddamfn=DDAMFNWeights,
    agegender=AgeGenderWeights,
    landmark=LandmarkWeights,
)
def _download_all(model_enum):
    """Download and verify every weight listed in ``model_enum``.

    Each weight is attempted in turn. A failure is printed and the loop moves on,
    so one bad download does not abort the rest of the batch.

    Args:
        model_enum: An iterable of weight members (one of the ``*Weights`` enums
            from ``uniface.constants``); each member's ``.value`` is printed.
    """
    for weight in model_enum:
        print(f"Downloading: {weight.value}")
        try:
            verify_model_weights(weight)
        except Exception as e:  # best-effort batch download: report and continue
            print(f"Failed to download {weight.value}: {e}")
        else:
            print(f"Downloaded: {weight.value}")


def main(argv=None):
    """Download and verify model weights from the command line.

    Behavior depends on which flags are given:

    * ``--model NAME`` only: download that single RetinaFace weight. This is
      kept for backward compatibility with the original RetinaFace-only usage.
    * ``--model-type TYPE --model NAME``: download the single weight ``NAME``
      of ``TYPE``.
    * ``--model-type TYPE`` only: download every weight of ``TYPE``.
    * no flags: download every weight of every type in ``ALL_MODEL_TYPES``.

    When a single model is requested, an invalid name prints the available
    names and returns. Errors from ``verify_model_weights`` propagate to the
    caller. In the bulk modes, per-weight failures are printed and skipped.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the previous zero-argument behavior.
    """
    parser = argparse.ArgumentParser(description="Download and verify model weights.")
    parser.add_argument(
        "--model-type",
        type=str,
        choices=list(ALL_MODEL_TYPES.keys()),
        help="Model type to download (e.g. retinaface, arcface). If not specified, all models will be downloaded.",
    )
    parser.add_argument(
        "--model",
        type=str,
        help=(
            "Specific model to download (e.g. MNET_V2). Combine with --model-type to pick "
            "a model of that type; without it, RetinaFace is assumed for backward compatibility."
        ),
    )
    args = parser.parse_args(argv)

    if args.model:
        # A single named model. Without --model-type, fall back to RetinaFace so
        # the legacy `--model MNET_V2` invocation keeps working unchanged.
        if args.model_type:
            label, model_enum = args.model_type, ALL_MODEL_TYPES[args.model_type]
        else:
            label, model_enum = "RetinaFace", RetinaFaceWeights

        # Restrict the try to the name lookup so that a KeyError raised while
        # downloading is not misreported as an invalid model name.
        try:
            weight = model_enum[args.model]
        except KeyError:
            print(f"Invalid {label} model: {args.model}")
            print(f"Available models: {[m.name for m in model_enum]}")
            return

        print(f"Downloading {label} model: {weight.value}")
        verify_model_weights(weight)
        print("Model downloaded successfully.")
        return

    if args.model_type:
        # Every weight of one model family.
        print(f"Downloading all {args.model_type} models...")
        _download_all(ALL_MODEL_TYPES[args.model_type])
    else:
        # Every weight of every model family.
        print("Downloading all models...")
        for model_type, model_enum in ALL_MODEL_TYPES.items():
            print(f"\nDownloading {model_type} models...")
            _download_all(model_enum)

    print("\nDownload process completed.")
# Run the download CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()