mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
Provide possibility to add provider
This commit is contained in:
@@ -6,23 +6,22 @@
|
||||
|
||||
|
||||
from __future__ import division
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
import glob
|
||||
import os
|
||||
import os.path as osp
|
||||
from numpy.linalg import norm
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
from numpy.linalg import norm
|
||||
|
||||
from ..model_zoo import model_zoo
|
||||
from ..utils import face_align
|
||||
from ..utils import ensure_available
|
||||
from ..utils import DEFAULT_MP_NAME, ensure_available
|
||||
from .common import Face
|
||||
from ..utils import DEFAULT_MP_NAME
|
||||
|
||||
__all__ = ['FaceAnalysis']
|
||||
|
||||
class FaceAnalysis:
|
||||
def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None):
|
||||
def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs):
|
||||
onnxruntime.set_default_logger_severity(3)
|
||||
self.models = {}
|
||||
self.model_dir = ensure_available('models', name, root=root)
|
||||
@@ -32,7 +31,8 @@ class FaceAnalysis:
|
||||
if onnx_file.find('_selfgen_')>0:
|
||||
#print('ignore:', onnx_file)
|
||||
continue
|
||||
model = model_zoo.get_model(onnx_file)
|
||||
model = model_zoo.get_model(onnx_file, **kwargs)
|
||||
|
||||
if model is None:
|
||||
print('model not recognized:', onnx_file)
|
||||
elif allowed_modules is not None and model.taskname not in allowed_modules:
|
||||
|
||||
@@ -22,12 +22,13 @@ class ModelRouter:
|
||||
def __init__(self, onnx_file):
|
||||
self.onnx_file = onnx_file
|
||||
|
||||
def get_model(self):
|
||||
session = onnxruntime.InferenceSession(self.onnx_file, None)
|
||||
def get_model(self, **kwargs):
|
||||
session = onnxruntime.InferenceSession(self.onnx_file, **kwargs)
|
||||
print(f'Applied providers: {session._providers}, with options: {session._provider_options}')
|
||||
input_cfg = session.get_inputs()[0]
|
||||
input_shape = input_cfg.shape
|
||||
outputs = session.get_outputs()
|
||||
#print(input_shape)
|
||||
|
||||
if len(outputs)>=5:
|
||||
return SCRFD(model_file=self.onnx_file, session=session)
|
||||
elif input_shape[2]==112 and input_shape[3]==112:
|
||||
@@ -66,6 +67,6 @@ def get_model(name, **kwargs):
|
||||
assert osp.exists(model_file), 'model_file should exist'
|
||||
assert osp.isfile(model_file), 'model_file should be file'
|
||||
router = ModelRouter(model_file)
|
||||
model = router.get_model()
|
||||
model = router.get_model(providers=kwargs.get('providers'), provider_options=kwargs.get('provider_options'))
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user