mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
Add output names and fix batch dim onnx export
This commit is contained in:
@@ -75,6 +75,7 @@ class SCRFD:
|
||||
self.model_file = model_file
|
||||
self.session = session
|
||||
self.taskname = 'detection'
|
||||
self.batched = False
|
||||
if self.session is None:
|
||||
assert self.model_file is not None
|
||||
assert osp.exists(self.model_file)
|
||||
@@ -96,6 +97,8 @@ class SCRFD:
|
||||
input_name = input_cfg.name
|
||||
self.input_shape = input_shape
|
||||
outputs = self.session.get_outputs()
|
||||
if len(outputs[0].shape) == 3:
|
||||
self.batched = True
|
||||
output_names = []
|
||||
for o in outputs:
|
||||
output_names.append(o.name)
|
||||
@@ -155,11 +158,21 @@ class SCRFD:
|
||||
input_width = blob.shape[3]
|
||||
fmc = self.fmc
|
||||
for idx, stride in enumerate(self._feat_stride_fpn):
|
||||
scores = net_outs[idx]
|
||||
bbox_preds = net_outs[idx+fmc]
|
||||
bbox_preds = bbox_preds * stride
|
||||
if self.use_kps:
|
||||
kps_preds = net_outs[idx+fmc*2] * stride
|
||||
# If model support batch dim, take first output
|
||||
if self.batched:
|
||||
scores = net_outs[idx][0]
|
||||
bbox_preds = net_outs[idx + fmc][0]
|
||||
bbox_preds = bbox_preds * stride
|
||||
if self.use_kps:
|
||||
kps_preds = net_outs[idx + fmc * 2][0] * stride
|
||||
# If model doesn't support batching take output as is
|
||||
else:
|
||||
scores = net_outs[idx]
|
||||
bbox_preds = net_outs[idx + fmc]
|
||||
bbox_preds = bbox_preds * stride
|
||||
if self.use_kps:
|
||||
kps_preds = net_outs[idx + fmc * 2] * stride
|
||||
|
||||
height = input_height // stride
|
||||
width = input_width // stride
|
||||
K = height * width
|
||||
|
||||
Reference in New Issue
Block a user