mirror of
https://github.com/yakhyo/uniface.git
synced 2025-12-30 09:02:25 +00:00
* chore: Rename scripts to tools folder and unify argument parser * refactor: Centralize dataclasses in types.py and add __call__ to all models - Move Face and result dataclasses to uniface/types.py - Add GazeResult, SpoofingResult, EmotionResult (frozen=True) - Add __call__ to BaseDetector, BaseRecognizer, BaseLandmarker - Add __repr__ to all dataclasses - Replace print() with Logger in onnx_utils.py - Update tools and docs to use new dataclass return types - Add test_types.py with comprehensive dataclass tests * chore: Rename files under tools folder and unify argument parser for them
61 lines
1.5 KiB
Python
61 lines
1.5 KiB
Python
import argparse
|
|
|
|
from uniface.constants import (
|
|
AgeGenderWeights,
|
|
ArcFaceWeights,
|
|
DDAMFNWeights,
|
|
LandmarkWeights,
|
|
MobileFaceWeights,
|
|
RetinaFaceWeights,
|
|
SCRFDWeights,
|
|
SphereFaceWeights,
|
|
)
|
|
from uniface.model_store import verify_model_weights
|
|
|
|
# Maps each accepted ``--model-type`` value to its weights enum from
# ``uniface.constants``. The insertion order below is also the order in
# which families are processed when no ``--model-type`` is given.
MODEL_TYPES = dict(
    retinaface=RetinaFaceWeights,
    sphereface=SphereFaceWeights,
    mobileface=MobileFaceWeights,
    arcface=ArcFaceWeights,
    scrfd=SCRFDWeights,
    ddamfn=DDAMFNWeights,
    agegender=AgeGenderWeights,
    landmark=LandmarkWeights,
)
|
|
|
|
|
|
def download_models(model_enum) -> list:
    """Download (or verify) every weight listed in one weights enum.

    Each member is passed to ``verify_model_weights`` and progress is printed
    to stdout. A failure for one member is reported and recorded, and the
    remaining members are still processed.

    Args:
        model_enum: An iterable weights enum (e.g. ``RetinaFaceWeights``); each
            member's ``.value`` is used as its display name.

    Returns:
        The members whose download/verification raised, in iteration order.
        The list is empty when every member succeeded.
    """
    failed = []
    for weight in model_enum:
        print(f'Downloading: {weight.value}')
        try:
            verify_model_weights(weight)
        except Exception as e:  # CLI boundary: report, remember, keep going
            print(f' Failed: {e}')
            failed.append(weight)
        else:
            print(f' Done: {weight.value}')
    return failed


def main(argv=None):
    """Command-line entry point: download one model family, or all of them.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    If any download fails, the failed weight names are printed to stderr and
    the process exits with status 1 (via ``parser.exit``), so shell scripts
    and CI jobs can detect the failure instead of seeing a silent success.
    """
    parser = argparse.ArgumentParser(description='Download model weights')
    parser.add_argument(
        '--model-type',
        type=str,
        choices=list(MODEL_TYPES.keys()),
        help='Model type to download. If not specified, downloads all.',
    )
    args = parser.parse_args(argv)

    failed = []
    if args.model_type:
        print(f'Downloading {args.model_type} models...')
        failed.extend(download_models(MODEL_TYPES[args.model_type]))
    else:
        print('Downloading all models...')
        for name, model_enum in MODEL_TYPES.items():
            print(f'\n{name}:')
            failed.extend(download_models(model_enum))

    if failed:
        # Previously failures were only printed and the script still exited 0.
        names = ', '.join(str(weight.value) for weight in failed)
        parser.exit(status=1, message=f'\n{len(failed)} model(s) failed: {names}\n')

    print('\nDone!')
|
|
|
|
|
|
# Run the downloader only when this file is executed as a script;
# importing the module (e.g. from tests) has no side effects.
if __name__ == '__main__':
    main()
|