mirror of
https://github.com/yakhyo/uniface.git
synced 2025-12-30 09:02:25 +00:00
- Add new test files for age_gender, factory, landmark, recognition, scrfd, and utils - Add new scripts for age_gender, landmarks, and video detection - Update documentation in README.md, MODELS.md, QUICKSTART.md - Improve model constants and face utilities - Update detection models (retinaface, scrfd) with enhanced functionality - Update project configuration in pyproject.toml
78 lines
2.7 KiB
Python
78 lines
2.7 KiB
Python
import argparse
|
|
from uniface.constants import (
|
|
RetinaFaceWeights, SphereFaceWeights, MobileFaceWeights, ArcFaceWeights,
|
|
SCRFDWeights, DDAMFNWeights, AgeGenderWeights, LandmarkWeights
|
|
)
|
|
from uniface.model_store import verify_model_weights
|
|
|
|
|
|
# Registry mapping each --model-type CLI choice to the weights enum whose
# members are iterated when that type is downloaded. Insertion order is the
# order used when every type is downloaded.
ALL_MODEL_TYPES = {
    'retinaface': RetinaFaceWeights,
    'sphereface': SphereFaceWeights,
    'mobileface': MobileFaceWeights,
    'arcface': ArcFaceWeights,
    'scrfd': SCRFDWeights,
    'ddamfn': DDAMFNWeights,
    'agegender': AgeGenderWeights,
    'landmark': LandmarkWeights,
}
|
|
|
|
|
|
def _download_weights(model_enum):
    """Run ``verify_model_weights`` on every member of ``model_enum``.

    A failure on one member is printed and does not stop the remaining
    members from being processed (best-effort batch).

    Args:
        model_enum: Iterable of weight enum members, e.g. ``RetinaFaceWeights``.
    """
    for weight in model_enum:
        print(f"Downloading: {weight.value}")
        try:
            verify_model_weights(weight)
        except Exception as e:  # deliberate: report and continue with the rest
            print(f"Failed to download {weight.value}: {e}")
        else:
            print(f"Downloaded: {weight.value}")


def main():
    """Command-line entry point for fetching model weights.

    Behaviour by argument combination:

    * ``--model NAME`` alone: NAME is looked up in ``RetinaFaceWeights``
      (kept for backward compatibility).
    * ``--model-type TYPE --model NAME``: NAME is looked up in the enum
      registered for TYPE, so only that single model is fetched.
    * ``--model-type TYPE`` alone: every model of that type is fetched.
    * no arguments: every model of every registered type is fetched.

    An unknown model name prints the valid names and returns without
    downloading anything. In the single-model paths, errors raised by
    ``verify_model_weights`` propagate to the caller.
    """
    parser = argparse.ArgumentParser(description="Download and verify model weights.")
    parser.add_argument(
        "--model-type",
        type=str,
        choices=list(ALL_MODEL_TYPES),
        help="Model type to download (e.g. retinaface, arcface). If not specified, all models will be downloaded.",
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Specific model to download (e.g. MNET_V2). Without --model-type, a RetinaFace model is assumed.",
    )
    args = parser.parse_args()

    if args.model:
        if args.model_type:
            # Honour --model together with --model-type instead of silently
            # ignoring it and downloading the whole family.
            model_enum = ALL_MODEL_TYPES[args.model_type]
            label = args.model_type
        else:
            # Backward compatibility - assume RetinaFace
            model_enum = RetinaFaceWeights
            label = "RetinaFace"

        # Only the name lookup is guarded, so a KeyError raised while
        # verifying/downloading is not misreported as an invalid model name.
        try:
            weight = model_enum[args.model]
        except KeyError:
            print(f"Invalid {label} model: {args.model}")
            print(f"Available models: {[m.name for m in model_enum]}")
            return

        print(f"Downloading {label} model: {weight.value}")
        verify_model_weights(weight)
        print("Model downloaded successfully.")
        return

    if args.model_type:
        # Download all models from specific type
        print(f"Downloading all {args.model_type} models...")
        _download_weights(ALL_MODEL_TYPES[args.model_type])
    else:
        # Download all models from all types
        print("Downloading all models...")
        for model_type, model_enum in ALL_MODEL_TYPES.items():
            print(f"\nDownloading {model_type} models...")
            _download_weights(model_enum)

    print("\nDownload process completed.")
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|