mirror of
https://github.com/deepinsight/insightface.git
synced 2026-04-15 06:30:25 +00:00
103 lines
3.0 KiB
Python
103 lines
3.0 KiB
Python
"""
|
|
This file is largely adapted from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_store.py
|
|
"""
|
|
from __future__ import print_function
|
|
|
|
__all__ = ['get_model_file']
|
|
import os
|
|
import zipfile
|
|
import glob
|
|
|
|
from ..utils import download, check_sha1
|
|
|
|
# Expected SHA-1 checksum for each downloadable model, keyed by model name.
# get_model_file compares this value against the extracted *.params file.
_model_sha1 = {
    'arcface_r100_v1': '95be21b58e29e9c1237f229dae534bd854009ce0',
    # NOTE(review): no checksum recorded for this model. How an empty hash
    # is treated is decided by check_sha1 in ..utils -- confirm.
    'arcface_mfn_v1': '',
    'retinaface_r50_v1': '39fd1e087a2a2ed70a154ac01fecaa86c315d01b',
    'retinaface_mnet025_v1': '2c9de8116d1f448fd1d4661f90308faae34c990a',
    'retinaface_mnet025_v2': '0db1d07921d005e6c9a5b38e059452fc5645e5a4',
    'genderage_v1': '7dd8111652b7aac2490c5dcddeb268e53ac643e6',
}

# Root URL of the online model zoo.
base_repo_url = 'https://insightface.ai/files/'

# Archive URL template; ``repo_url`` is expected to end with '/'.
_url_format = '{repo_url}models/{file_name}.zip'
|
|
|
|
|
|
def short_hash(name):
    """Return the first (up to) 8 characters of the checksum registered for *name*.

    Parameters
    ----------
    name : str
        Model name to look up in ``_model_sha1``.

    Raises
    ------
    ValueError
        If *name* has no entry in ``_model_sha1``.
    """
    checksum = _model_sha1.get(name)
    if checksum is None:
        raise ValueError(
            'Pretrained model for {name} is not available.'.format(name=name))
    return checksum[:8]
|
|
|
|
|
|
def find_params_file(dir_path):
    """Return the path of the lexicographically last ``*.params`` file in *dir_path*.

    Only the directory itself is searched (not sub-directories). Names that
    start with '.' are ignored, matching what ``glob('*.params')`` would do.

    Parameters
    ----------
    dir_path : str
        Directory to search.

    Returns
    -------
    str or None
        ``os.path.join(dir_path, <name>)`` for the greatest matching name, or
        ``None`` if *dir_path* does not exist, cannot be listed, or contains
        no matching file.
    """
    # Local import keeps this fix independent of the module import block.
    import fnmatch

    if not os.path.exists(dir_path):
        return None
    try:
        entries = os.listdir(dir_path)
    except OSError:
        # glob() yields nothing for a path it cannot list (e.g. a regular
        # file or an unreadable directory); keep returning None here too.
        return None
    # Match on file names only, rather than globbing "dir_path/*.params".
    # With glob, characters such as '[' or '*' inside dir_path itself would
    # be interpreted as wildcards, and existing model files would be missed.
    visible = [entry for entry in entries if not entry.startswith('.')]
    matches = fnmatch.filter(visible, '*.params')
    if not matches:
        return None
    # All candidates share the same directory prefix, so the greatest name
    # yields the same result as sorting the full paths and taking the last.
    return os.path.join(dir_path, max(matches))
|
|
|
|
|
|
def get_model_file(name, root=os.path.join('~', '.insightface', 'models')):
    r"""Return location for the pretrained model on the local file system.

    The model is looked up in ``root/name``. If no ``*.params`` file is found
    there, or the one found fails the checksum test, the model archive is
    downloaded from the online model zoo and extracted into that directory.
    Directories are created on demand when a download is needed.

    Parameters
    ----------
    name : str
        Name of the model; must be a key of ``_model_sha1``.
    root : str, default '~/.insightface/models'
        Location for keeping the model parameters. ``~`` is expanded.

    Returns
    -------
    file_path
        Path to the requested pretrained model file (the ``*.params`` file
        selected by ``find_params_file``).

    Raises
    ------
    KeyError
        If *name* has no entry in ``_model_sha1``.
    ValueError
        If no ``*.params`` file is found after extracting the downloaded
        archive, or the extracted file fails the checksum test.
    """
    root = os.path.expanduser(root)
    dir_path = os.path.join(root, name)
    file_path = find_params_file(dir_path)
    sha1_hash = _model_sha1[name]

    if file_path is not None:
        if check_sha1(file_path, sha1_hash):
            return file_path
        print('Mismatch in the content of model file detected. Downloading again.')
    else:
        print('Model file is not found. Downloading.')

    # dir_path lives inside root, so this also creates root when missing.
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    repo_url = base_repo_url
    if not repo_url.endswith('/'):
        repo_url += '/'

    zip_file_path = os.path.join(root, name + '.zip')
    download(_url_format.format(repo_url=repo_url, file_name=name),
             path=zip_file_path,
             overwrite=True)
    try:
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall(dir_path)
    finally:
        # Remove the archive even if extraction fails, so a broken download
        # is not left behind in root.
        os.remove(zip_file_path)

    file_path = find_params_file(dir_path)
    if file_path is None:
        # Without this check, check_sha1 would be handed None and fail with
        # an unrelated error.
        raise ValueError(
            'No .params file found in {path} after extracting the downloaded '
            'archive for {name}.'.format(path=dir_path, name=name))
    if check_sha1(file_path, sha1_hash):
        return file_path
    raise ValueError('Downloaded file has different hash. Please try again.')
|