Provide possibility to add provider

This commit is contained in:
Tanja Bayer
2021-08-24 16:53:53 +02:00
parent cdc3d4ed5d
commit 6538da9787
2 changed files with 14 additions and 13 deletions

View File

@@ -6,23 +6,22 @@
from __future__ import division
import collections
import numpy as np
import glob
import os
import os.path as osp
from numpy.linalg import norm
import numpy as np
import onnxruntime
from numpy.linalg import norm
from ..model_zoo import model_zoo
from ..utils import face_align
from ..utils import ensure_available
from ..utils import DEFAULT_MP_NAME, ensure_available
from .common import Face
from ..utils import DEFAULT_MP_NAME
__all__ = ['FaceAnalysis']
class FaceAnalysis:
def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None):
def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs):
onnxruntime.set_default_logger_severity(3)
self.models = {}
self.model_dir = ensure_available('models', name, root=root)
@@ -32,7 +31,8 @@ class FaceAnalysis:
if onnx_file.find('_selfgen_')>0:
#print('ignore:', onnx_file)
continue
model = model_zoo.get_model(onnx_file)
model = model_zoo.get_model(onnx_file, **kwargs)
if model is None:
print('model not recognized:', onnx_file)
elif allowed_modules is not None and model.taskname not in allowed_modules:

View File

@@ -22,12 +22,13 @@ class ModelRouter:
def __init__(self, onnx_file):
self.onnx_file = onnx_file
def get_model(self):
session = onnxruntime.InferenceSession(self.onnx_file, None)
def get_model(self, **kwargs):
session = onnxruntime.InferenceSession(self.onnx_file, **kwargs)
print(f'Applied providers: {session._providers}, with options: {session._provider_options}')
input_cfg = session.get_inputs()[0]
input_shape = input_cfg.shape
outputs = session.get_outputs()
#print(input_shape)
if len(outputs)>=5:
return SCRFD(model_file=self.onnx_file, session=session)
elif input_shape[2]==112 and input_shape[3]==112:
@@ -66,6 +67,6 @@ def get_model(name, **kwargs):
assert osp.exists(model_file), 'model_file should exist'
assert osp.isfile(model_file), 'model_file should be file'
router = ModelRouter(model_file)
model = router.get_model()
model = router.get_model(providers=kwargs.get('providers'), provider_options=kwargs.get('provider_options'))
return model