Mirror of https://github.com/deepinsight/insightface.git
Commit: code of pip insightface==0.2
python-package/insightface/__init__.py
@@ -3,25 +3,15 @@
"""InsightFace: A Face Analysis Toolkit."""
from __future__ import absolute_import

# mxnet version check
#mx_version = '1.4.0'
try:
    import mxnet as mx
    #from distutils.version import LooseVersion
    #if LooseVersion(mx.__version__) < LooseVersion(mx_version):
    #    msg = (
    #        "Legacy mxnet-mkl=={} detected, some new modules may not work properly. "
    #        "mxnet-mkl>={} is required. You can use pip to upgrade mxnet "
    #        "`pip install mxnet-mkl --pre --upgrade` "
    #        "or `pip install mxnet-cu90mkl --pre --upgrade`").format(mx.__version__, mx_version)
    #    raise ImportError(msg)
    #import mxnet as mx
    import onnxruntime
except ImportError:
    raise ImportError(
        "Unable to import dependency mxnet. "
        "A quick tip is to install via `pip install mxnet-mkl/mxnet-cu90mkl --pre`. "
        "Unable to import dependency onnxruntime. "
    )

__version__ = '0.1.5'
__version__ = '0.2.0'

from . import model_zoo
from . import utils

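The net effect of this hunk: version 0.2 imports onnxruntime instead of mxnet at package load time (the old mxnet version check survives only as comments). A quick sanity-check sketch, using only standard onnxruntime/insightface attributes:

    import onnxruntime
    print(onnxruntime.get_available_providers())  # e.g. ['CPUExecutionProvider']

    import insightface
    print(insightface.__version__)                # expect '0.2.0'
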
python-package/insightface/app/face_analysis.py
@@ -1,85 +1,100 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author       : Jia Guo
# @Time         : 2021-05-04
# @Function     :


from __future__ import division
import collections
import mxnet as mx
import numpy as np
import glob
import os
import os.path as osp
from numpy.linalg import norm
import mxnet.ndarray as nd
from ..model_zoo import model_zoo
from ..utils import face_align

__all__ = ['FaceAnalysis', 'Face']

Face = collections.namedtuple('Face', [
    'bbox', 'landmark', 'det_score', 'embedding', 'gender', 'age',
    'embedding_norm', 'normed_embedding'
    'bbox', 'kps', 'det_score', 'embedding', 'gender', 'age',
    'embedding_norm', 'normed_embedding',
    'landmark'
])

Face.__new__.__defaults__ = (None, ) * len(Face._fields)


class FaceAnalysis:
    def __init__(self,
                 det_name='retinaface_r50_v1',
                 rec_name='arcface_r100_v1',
                 ga_name='genderage_v1'):
        assert det_name is not None
        self.det_model = model_zoo.get_model(det_name)
        if rec_name is not None:
            self.rec_model = model_zoo.get_model(rec_name)
        else:
            self.rec_model = None
        if ga_name is not None:
            self.ga_model = model_zoo.get_model(ga_name)
        else:
            self.ga_model = None
    def __init__(self, name, root='~/.insightface/models'):
        self.models = {}
        root = os.path.expanduser(root)
        onnx_files = glob.glob(osp.join(root, name, '*.onnx'))
        onnx_files = sorted(onnx_files)
        for onnx_file in onnx_files:
            if onnx_file.find('_selfgen_') > 0:
                #print('ignore:', onnx_file)
                continue
            model = model_zoo.get_model(onnx_file)
            if model.taskname not in self.models:
                print('find model:', onnx_file, model.taskname)
                self.models[model.taskname] = model
            else:
                print('duplicated model task type, ignore:', onnx_file, model.taskname)
                del model
        assert 'detection' in self.models
        self.det_model = self.models['detection']

    def prepare(self, ctx_id, nms=0.4):
        self.det_model.prepare(ctx_id, nms)
        if self.rec_model is not None:
            self.rec_model.prepare(ctx_id)
        if self.ga_model is not None:
            self.ga_model.prepare(ctx_id)

    def get(self, img, det_thresh=0.8, det_scale=1.0, max_num=0):
        bboxes, landmarks = self.det_model.detect(img,
                                                  threshold=det_thresh,
                                                  scale=det_scale)
    def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)):
        self.det_thresh = det_thresh
        assert det_size is not None
        print('set det-size:', det_size)
        self.det_size = det_size
        for taskname, model in self.models.items():
            if taskname == 'detection':
                model.prepare(ctx_id, input_size=det_size)
            else:
                model.prepare(ctx_id)

    def get(self, img, max_num=0):
        bboxes, kpss = self.det_model.detect(img,
                                             threshold=self.det_thresh,
                                             max_num=max_num,
                                             metric='default')
        if bboxes.shape[0] == 0:
            return []
        if max_num > 0 and bboxes.shape[0] > max_num:
            area = (bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1])
            img_center = img.shape[0] // 2, img.shape[1] // 2
            offsets = np.vstack([
                (bboxes[:, 0] + bboxes[:, 2]) / 2 - img_center[1],
                (bboxes[:, 1] + bboxes[:, 3]) / 2 - img_center[0]
            ])
            offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
            values = area - offset_dist_squared * 2.0  # some extra weight on the centering
            bindex = np.argsort(values)[::-1]  # some extra weight on the centering
            bindex = bindex[0:max_num]
            bboxes = bboxes[bindex, :]
            landmarks = landmarks[bindex, :]
        ret = []
        for i in range(bboxes.shape[0]):
            bbox = bboxes[i, 0:4]
            det_score = bboxes[i, 4]
            landmark = landmarks[i]
            _img = face_align.norm_crop(img, landmark=landmark)
            kps = None
            if kpss is not None:
                kps = kpss[i]
            embedding = None
            embedding_norm = None
            normed_embedding = None
            embedding_norm = None
            gender = None
            age = None
            if self.rec_model is not None:
                embedding = self.rec_model.get_embedding(_img).flatten()
            if 'recognition' in self.models:
                assert kps is not None
                rec_model = self.models['recognition']
                aimg = face_align.norm_crop(img, landmark=kps)
                embedding = None
                embedding_norm = None
                normed_embedding = None
                gender = None
                age = None
                embedding = rec_model.get_feat(aimg).flatten()
                embedding_norm = norm(embedding)
                normed_embedding = embedding / embedding_norm
            if self.ga_model is not None:
                gender, age = self.ga_model.get(_img)
            if 'genderage' in self.models:
                assert aimg is not None
                ga_model = self.models['genderage']
                gender, age = ga_model.get(aimg)
            face = Face(bbox=bbox,
                        landmark=landmark,
                        kps=kps,
                        det_score=det_score,
                        embedding=embedding,
                        gender=gender,
@@ -88,3 +103,4 @@ class FaceAnalysis:
                        embedding_norm=embedding_norm)
            ret.append(face)
        return ret

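Taken together, the class now discovers ONNX models on disk instead of taking three fixed model names. A minimal usage sketch; the pack name 'antelope' and the exact import path are assumptions from the file layout, not guaranteed by this diff:

    import cv2
    from insightface.app.face_analysis import FaceAnalysis

    app = FaceAnalysis(name='antelope')  # expects ~/.insightface/models/antelope/*.onnx
    app.prepare(ctx_id=-1, det_thresh=0.5, det_size=(640, 640))  # ctx_id < 0 selects CPU
    img = cv2.imread('t1.jpg')
    for face in app.get(img):
        # 'kps' carries the detector's 5-point keypoints; 'landmark' stays in
        # the namedtuple for backward compatibility and may be None.
        print(face.bbox, face.det_score, face.gender, face.age)
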
python-package/insightface/model_zoo/__init__.py
@@ -1 +1,3 @@
from .model_zoo import get_model, get_model_list
from .model_zoo import get_model
from .arcface_onnx import ArcFaceONNX
from .scrfd import SCRFD

python-package/insightface/model_zoo/arcface_onnx.py (new file, 82 lines)
@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author       : Jia Guo
# @Time         : 2021-05-04
# @Function     :

from __future__ import division
import numpy as np
import cv2
import onnx
import onnxruntime
from ..utils import face_align

__all__ = [
    'ArcFaceONNX',
]


class ArcFaceONNX:
    def __init__(self, model_file=None, session=None):
        import onnxruntime
        assert model_file is not None
        self.model_file = model_file
        self.session = session
        self.taskname = 'recognition'
        find_sub = False
        find_mul = False
        model = onnx.load(self.model_file)
        graph = model.graph
        for nid, node in enumerate(graph.node[:8]):
            #print(nid, node.name)
            if node.name.startswith('Sub') or node.name.startswith('_minus'):
                find_sub = True
            if node.name.startswith('Mul') or node.name.startswith('_mul'):
                find_mul = True
        if find_sub and find_mul:
            #mxnet arcface model
            input_mean = 0.0
            input_std = 1.0
        else:
            input_mean = 127.5
            input_std = 127.5
        self.input_mean = input_mean
        self.input_std = input_std
        print('input mean and std:', self.input_mean, self.input_std)
        if self.session is None:
            self.session = onnxruntime.InferenceSession(self.model_file, None)
        input_cfg = self.session.get_inputs()[0]
        input_shape = input_cfg.shape
        input_name = input_cfg.name
        self.input_size = tuple(input_shape[2:4][::-1])
        outputs = self.session.get_outputs()
        output_names = []
        for out in outputs:
            output_names.append(out.name)
        self.input_name = input_name
        self.output_names = output_names
        assert len(self.output_names) == 1

    def prepare(self, ctx_id, **kwargs):
        if ctx_id < 0:
            self.session.set_providers(['CPUExecutionProvider'])

    def get_feat(self, img):
        assert img.shape[2] == 3
        input_size = tuple(img.shape[0:2][::-1])
        assert input_size == self.input_size
        blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size,
                                     (self.input_mean, self.input_mean, self.input_mean),
                                     swapRB=True)
        net_outs = self.session.run(self.output_names, {self.input_name: blob})
        feat = net_outs[0]
        return feat

    def compute_sim(self, feat1, feat2):
        from numpy.linalg import norm
        feat1 = feat1.ravel()
        feat2 = feat2.ravel()
        sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))
        return sim

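A short usage sketch for the new recognition wrapper (ArcFaceONNX is exported from insightface.model_zoo, as the __init__ hunk above shows; the .onnx filename is a placeholder). get_feat expects crops already aligned to the model's 112x112 input:

    import cv2
    from insightface.model_zoo import ArcFaceONNX

    rec = ArcFaceONNX(model_file='arcface_r100.onnx')  # placeholder model path
    rec.prepare(ctx_id=-1)  # CPU
    feat1 = rec.get_feat(cv2.imread('crop1.jpg'))
    feat2 = rec.get_feat(cv2.imread('crop2.jpg'))
    print('cosine similarity:', rec.compute_sim(feat1, feat2))
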
@@ -1,466 +0,0 @@
|
||||
from __future__ import division
|
||||
import mxnet as mx
|
||||
import numpy as np
|
||||
import mxnet.ndarray as nd
|
||||
import cv2
|
||||
|
||||
__all__ = [
|
||||
'FaceDetector', 'retinaface_r50_v1', 'retinaface_mnet025_v1',
|
||||
'retinaface_mnet025_v2', 'get_retinaface'
|
||||
]
|
||||
|
||||
|
||||
def _whctrs(anchor):
|
||||
"""
|
||||
Return width, height, x center, and y center for an anchor (window).
|
||||
"""
|
||||
|
||||
w = anchor[2] - anchor[0] + 1
|
||||
h = anchor[3] - anchor[1] + 1
|
||||
x_ctr = anchor[0] + 0.5 * (w - 1)
|
||||
y_ctr = anchor[1] + 0.5 * (h - 1)
|
||||
return w, h, x_ctr, y_ctr
|
||||
|
||||
|
||||
def _mkanchors(ws, hs, x_ctr, y_ctr):
|
||||
"""
|
||||
Given a vector of widths (ws) and heights (hs) around a center
|
||||
(x_ctr, y_ctr), output a set of anchors (windows).
|
||||
"""
|
||||
|
||||
ws = ws[:, np.newaxis]
|
||||
hs = hs[:, np.newaxis]
|
||||
anchors = np.hstack((x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
|
||||
x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)))
|
||||
return anchors
|
||||
|
||||
|
||||
def _ratio_enum(anchor, ratios):
|
||||
"""
|
||||
Enumerate a set of anchors for each aspect ratio wrt an anchor.
|
||||
"""
|
||||
|
||||
w, h, x_ctr, y_ctr = _whctrs(anchor)
|
||||
size = w * h
|
||||
size_ratios = size / ratios
|
||||
ws = np.round(np.sqrt(size_ratios))
|
||||
hs = np.round(ws * ratios)
|
||||
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
|
||||
return anchors
|
||||
|
||||
|
||||
def _scale_enum(anchor, scales):
|
||||
"""
|
||||
Enumerate a set of anchors for each scale wrt an anchor.
|
||||
"""
|
||||
|
||||
w, h, x_ctr, y_ctr = _whctrs(anchor)
|
||||
ws = w * scales
|
||||
hs = h * scales
|
||||
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
|
||||
return anchors
|
||||
|
||||
|
||||
def anchors_plane(height, width, stride, base_anchors):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
height: height of plane
|
||||
width: width of plane
|
||||
stride: stride ot the original image
|
||||
anchors_base: (A, 4) a base set of anchors
|
||||
Returns
|
||||
-------
|
||||
all_anchors: (height, width, A, 4) ndarray of anchors spreading over the plane
|
||||
"""
|
||||
A = base_anchors.shape[0]
|
||||
all_anchors = np.zeros((height, width, A, 4), dtype=np.float32)
|
||||
for iw in range(width):
|
||||
sw = iw * stride
|
||||
for ih in range(height):
|
||||
sh = ih * stride
|
||||
for k in range(A):
|
||||
all_anchors[ih, iw, k, 0] = base_anchors[k, 0] + sw
|
||||
all_anchors[ih, iw, k, 1] = base_anchors[k, 1] + sh
|
||||
all_anchors[ih, iw, k, 2] = base_anchors[k, 2] + sw
|
||||
all_anchors[ih, iw, k, 3] = base_anchors[k, 3] + sh
|
||||
return all_anchors
|
||||
|
||||
|
||||
def generate_anchors(base_size=16,
|
||||
ratios=[0.5, 1, 2],
|
||||
scales=2**np.arange(3, 6),
|
||||
stride=16):
|
||||
"""
|
||||
Generate anchor (reference) windows by enumerating aspect ratios X
|
||||
scales wrt a reference (0, 0, 15, 15) window.
|
||||
"""
|
||||
|
||||
base_anchor = np.array([1, 1, base_size, base_size]) - 1
|
||||
ratio_anchors = _ratio_enum(base_anchor, ratios)
|
||||
anchors = np.vstack([
|
||||
_scale_enum(ratio_anchors[i, :], scales)
|
||||
for i in range(ratio_anchors.shape[0])
|
||||
])
|
||||
return anchors
|
||||
|
||||
|
||||
def generate_anchors_fpn(cfg):
|
||||
"""
|
||||
Generate anchor (reference) windows by enumerating aspect ratios X
|
||||
scales wrt a reference (0, 0, 15, 15) window.
|
||||
"""
|
||||
RPN_FEAT_STRIDE = []
|
||||
for k in cfg:
|
||||
RPN_FEAT_STRIDE.append(int(k))
|
||||
RPN_FEAT_STRIDE = sorted(RPN_FEAT_STRIDE, reverse=True)
|
||||
anchors = []
|
||||
for k in RPN_FEAT_STRIDE:
|
||||
v = cfg[str(k)]
|
||||
bs = v['BASE_SIZE']
|
||||
__ratios = np.array(v['RATIOS'])
|
||||
__scales = np.array(v['SCALES'])
|
||||
stride = int(k)
|
||||
#print('anchors_fpn', bs, __ratios, __scales, file=sys.stderr)
|
||||
r = generate_anchors(bs, __ratios, __scales, stride)
|
||||
#print('anchors_fpn', r.shape, file=sys.stderr)
|
||||
anchors.append(r)
|
||||
|
||||
return anchors
|
||||
|
||||
|
||||
def clip_pad(tensor, pad_shape):
|
||||
"""
|
||||
Clip boxes of the pad area.
|
||||
:param tensor: [n, c, H, W]
|
||||
:param pad_shape: [h, w]
|
||||
:return: [n, c, h, w]
|
||||
"""
|
||||
H, W = tensor.shape[2:]
|
||||
h, w = pad_shape
|
||||
|
||||
if h < H or w < W:
|
||||
tensor = tensor[:, :, :h, :w].copy()
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def bbox_pred(boxes, box_deltas):
|
||||
"""
|
||||
Transform the set of class-agnostic boxes into class-specific boxes
|
||||
by applying the predicted offsets (box_deltas)
|
||||
:param boxes: !important [N 4]
|
||||
:param box_deltas: [N, 4 * num_classes]
|
||||
:return: [N 4 * num_classes]
|
||||
"""
|
||||
if boxes.shape[0] == 0:
|
||||
return np.zeros((0, box_deltas.shape[1]))
|
||||
|
||||
boxes = boxes.astype(np.float, copy=False)
|
||||
widths = boxes[:, 2] - boxes[:, 0] + 1.0
|
||||
heights = boxes[:, 3] - boxes[:, 1] + 1.0
|
||||
ctr_x = boxes[:, 0] + 0.5 * (widths - 1.0)
|
||||
ctr_y = boxes[:, 1] + 0.5 * (heights - 1.0)
|
||||
|
||||
dx = box_deltas[:, 0:1]
|
||||
dy = box_deltas[:, 1:2]
|
||||
dw = box_deltas[:, 2:3]
|
||||
dh = box_deltas[:, 3:4]
|
||||
|
||||
pred_ctr_x = dx * widths[:, np.newaxis] + ctr_x[:, np.newaxis]
|
||||
pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis]
|
||||
pred_w = np.exp(dw) * widths[:, np.newaxis]
|
||||
pred_h = np.exp(dh) * heights[:, np.newaxis]
|
||||
|
||||
pred_boxes = np.zeros(box_deltas.shape)
|
||||
# x1
|
||||
pred_boxes[:, 0:1] = pred_ctr_x - 0.5 * (pred_w - 1.0)
|
||||
# y1
|
||||
pred_boxes[:, 1:2] = pred_ctr_y - 0.5 * (pred_h - 1.0)
|
||||
# x2
|
||||
pred_boxes[:, 2:3] = pred_ctr_x + 0.5 * (pred_w - 1.0)
|
||||
# y2
|
||||
pred_boxes[:, 3:4] = pred_ctr_y + 0.5 * (pred_h - 1.0)
|
||||
|
||||
if box_deltas.shape[1] > 4:
|
||||
pred_boxes[:, 4:] = box_deltas[:, 4:]
|
||||
|
||||
return pred_boxes
|
||||
|
||||
|
||||
def landmark_pred(boxes, landmark_deltas):
|
||||
if boxes.shape[0] == 0:
|
||||
return np.zeros((0, landmark_deltas.shape[1]))
|
||||
boxes = boxes.astype(np.float, copy=False)
|
||||
widths = boxes[:, 2] - boxes[:, 0] + 1.0
|
||||
heights = boxes[:, 3] - boxes[:, 1] + 1.0
|
||||
ctr_x = boxes[:, 0] + 0.5 * (widths - 1.0)
|
||||
ctr_y = boxes[:, 1] + 0.5 * (heights - 1.0)
|
||||
pred = landmark_deltas.copy()
|
||||
for i in range(5):
|
||||
pred[:, i, 0] = landmark_deltas[:, i, 0] * widths + ctr_x
|
||||
pred[:, i, 1] = landmark_deltas[:, i, 1] * heights + ctr_y
|
||||
return pred
|
||||
|
||||
|
||||
class FaceDetector:
|
||||
def __init__(self, param_file, rac):
|
||||
self.param_file = param_file
|
||||
self.rac = rac
|
||||
self.default_image_size = (480, 640)
|
||||
|
||||
def prepare(self, ctx_id, nms=0.4, fix_image_size=None):
|
||||
pos = self.param_file.rfind('-')
|
||||
prefix = self.param_file[0:pos]
|
||||
pos2 = self.param_file.rfind('.')
|
||||
epoch = int(self.param_file[pos + 1:pos2])
|
||||
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
|
||||
if ctx_id >= 0:
|
||||
ctx = mx.gpu(ctx_id)
|
||||
else:
|
||||
ctx = mx.cpu()
|
||||
model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
|
||||
if fix_image_size is not None:
|
||||
data_shape = (1, 3) + fix_image_size
|
||||
else:
|
||||
data_shape = (1, 3) + self.default_image_size
|
||||
model.bind(data_shapes=[('data', data_shape)])
|
||||
model.set_params(arg_params, aux_params)
|
||||
#warmup
|
||||
data = mx.nd.zeros(shape=data_shape)
|
||||
db = mx.io.DataBatch(data=(data, ))
|
||||
model.forward(db, is_train=False)
|
||||
out = model.get_outputs()[0].asnumpy()
|
||||
self.model = model
|
||||
self.nms_threshold = nms
|
||||
|
||||
self.landmark_std = 1.0
|
||||
_ratio = (1., )
|
||||
fmc = 3
|
||||
if self.rac == 'net3':
|
||||
_ratio = (1., )
|
||||
elif self.rac == 'net3l':
|
||||
_ratio = (1., )
|
||||
self.landmark_std = 0.2
|
||||
        elif self.rac == 'net5':  #retinaface
            fmc = 5
        else:
            assert False, 'rac setting error %s' % self.rac

        if fmc == 3:
            self._feat_stride_fpn = [32, 16, 8]
            self.anchor_cfg = {
                '32': {
                    'SCALES': (32, 16),
                    'BASE_SIZE': 16,
                    'RATIOS': _ratio,
                    'ALLOWED_BORDER': 9999
                },
                '16': {
                    'SCALES': (8, 4),
                    'BASE_SIZE': 16,
                    'RATIOS': _ratio,
                    'ALLOWED_BORDER': 9999
                },
                '8': {
                    'SCALES': (2, 1),
                    'BASE_SIZE': 16,
                    'RATIOS': _ratio,
                    'ALLOWED_BORDER': 9999
                },
            }
        elif fmc == 5:
            self._feat_stride_fpn = [64, 32, 16, 8, 4]
            self.anchor_cfg = {}
            _ass = 2.0**(1.0 / 3)
            _basescale = 1.0
            for _stride in [4, 8, 16, 32, 64]:
                key = str(_stride)
                value = {
                    'BASE_SIZE': 16,
                    'RATIOS': _ratio,
                    'ALLOWED_BORDER': 9999
                }
                scales = []
                for _ in range(3):
                    scales.append(_basescale)
                    _basescale *= _ass
                value['SCALES'] = tuple(scales)
                self.anchor_cfg[key] = value

        print(self._feat_stride_fpn, self.anchor_cfg)
        self.use_landmarks = False
        if len(sym) // len(self._feat_stride_fpn) == 3:
            self.use_landmarks = True
        print('use_landmarks', self.use_landmarks)
        self.fpn_keys = []

        for s in self._feat_stride_fpn:
            self.fpn_keys.append('stride%s' % s)

        self._anchors_fpn = dict(
            zip(self.fpn_keys, generate_anchors_fpn(cfg=self.anchor_cfg)))
        for k in self._anchors_fpn:
            v = self._anchors_fpn[k].astype(np.float32)
            self._anchors_fpn[k] = v
        self.anchor_plane_cache = {}

        self._num_anchors = dict(
            zip(self.fpn_keys,
                [anchors.shape[0] for anchors in self._anchors_fpn.values()]))

    def detect(self, img, threshold=0.5, scale=1.0):
        proposals_list = []
        scores_list = []
        landmarks_list = []
        if scale == 1.0:
            im = img
        else:
            im = cv2.resize(img,
                            None,
                            None,
                            fx=scale,
                            fy=scale,
                            interpolation=cv2.INTER_LINEAR)
        im_info = [im.shape[0], im.shape[1]]
        im_tensor = np.zeros((1, 3, im.shape[0], im.shape[1]))
        for i in range(3):
            im_tensor[0, i, :, :] = im[:, :, 2 - i]
        data = nd.array(im_tensor)
        db = mx.io.DataBatch(data=(data, ),
                             provide_data=[('data', data.shape)])
        self.model.forward(db, is_train=False)
        net_out = self.model.get_outputs()
        for _idx, s in enumerate(self._feat_stride_fpn):
            _key = 'stride%s' % s
            stride = int(s)
            if self.use_landmarks:
                idx = _idx * 3
            else:
                idx = _idx * 2
            scores = net_out[idx].asnumpy()
            scores = scores[:, self._num_anchors['stride%s' % s]:, :, :]
            idx += 1
            bbox_deltas = net_out[idx].asnumpy()

            height, width = bbox_deltas.shape[2], bbox_deltas.shape[3]
            A = self._num_anchors['stride%s' % s]
            K = height * width
            key = (height, width, stride)
            if key in self.anchor_plane_cache:
                anchors = self.anchor_plane_cache[key]
            else:
                anchors_fpn = self._anchors_fpn['stride%s' % s]
                anchors = anchors_plane(height, width, stride, anchors_fpn)
                anchors = anchors.reshape((K * A, 4))
                if len(self.anchor_plane_cache) < 100:
                    self.anchor_plane_cache[key] = anchors

            scores = clip_pad(scores, (height, width))
            scores = scores.transpose((0, 2, 3, 1)).reshape((-1, 1))

            bbox_deltas = clip_pad(bbox_deltas, (height, width))
            bbox_deltas = bbox_deltas.transpose((0, 2, 3, 1))
            bbox_pred_len = bbox_deltas.shape[3] // A
            bbox_deltas = bbox_deltas.reshape((-1, bbox_pred_len))

            proposals = bbox_pred(anchors, bbox_deltas)
            #proposals = clip_boxes(proposals, im_info[:2])

            scores_ravel = scores.ravel()
            order = np.where(scores_ravel >= threshold)[0]
            proposals = proposals[order, :]
            scores = scores[order]

            proposals[:, 0:4] /= scale

            proposals_list.append(proposals)
            scores_list.append(scores)

            if self.use_landmarks:
                idx += 1
                landmark_deltas = net_out[idx].asnumpy()
                landmark_deltas = clip_pad(landmark_deltas, (height, width))
                landmark_pred_len = landmark_deltas.shape[1] // A
                landmark_deltas = landmark_deltas.transpose(
                    (0, 2, 3, 1)).reshape((-1, 5, landmark_pred_len // 5))
                landmark_deltas *= self.landmark_std
                #print(landmark_deltas.shape, landmark_deltas)
                landmarks = landmark_pred(anchors, landmark_deltas)
                landmarks = landmarks[order, :]

                landmarks[:, :, 0:2] /= scale
                landmarks_list.append(landmarks)

        proposals = np.vstack(proposals_list)
        landmarks = None
        if proposals.shape[0] == 0:
            if self.use_landmarks:
                landmarks = np.zeros((0, 5, 2))
            return np.zeros((0, 5)), landmarks
        scores = np.vstack(scores_list)
        scores_ravel = scores.ravel()
        order = scores_ravel.argsort()[::-1]
        proposals = proposals[order, :]
        scores = scores[order]
        if self.use_landmarks:
            landmarks = np.vstack(landmarks_list)
            landmarks = landmarks[order].astype(np.float32, copy=False)

        pre_det = np.hstack((proposals[:, 0:4], scores)).astype(np.float32,
                                                                copy=False)
        keep = self.nms(pre_det)
        det = np.hstack((pre_det, proposals[:, 4:]))
        det = det[keep, :]
        if self.use_landmarks:
            landmarks = landmarks[keep]

        return det, landmarks

    def nms(self, dets):
        thresh = self.nms_threshold
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = dets[:, 2]
        y2 = dets[:, 3]
        scores = dets[:, 4]

        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)

            inds = np.where(ovr <= thresh)[0]
            order = order[inds + 1]

        return keep


def get_retinaface(name, rac='net3', root='~/.insightface/models', **kwargs):
    from .model_store import get_model_file
    _file = get_model_file("retinaface_%s" % name, root=root)
    return FaceDetector(_file, rac)


def retinaface_r50_v1(**kwargs):
    return get_retinaface("r50_v1", rac='net3', **kwargs)


def retinaface_mnet025_v1(**kwargs):
    return get_retinaface("mnet025_v1", rac='net3', **kwargs)


def retinaface_mnet025_v2(**kwargs):
    return get_retinaface("mnet025_v2", rac='net3l', **kwargs)
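The greedy IoU-based nms above is carried over verbatim into the new SCRFD class later in this diff. A self-contained numeric check of the same algorithm, in plain numpy:

    import numpy as np

    def nms(dets, thresh=0.4):
        # identical logic to FaceDetector.nms / SCRFD.nms
        x1, y1, x2, y2 = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3]
        scores = dets[:, 4]
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]
        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])
            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)
            order = order[np.where(ovr <= thresh)[0] + 1]
        return keep

    dets = np.array([[10., 10., 50., 50., 0.9],
                     [12., 12., 52., 52., 0.8],       # IoU ~0.83 with box 0 -> suppressed
                     [100., 100., 150., 150., 0.7]])  # disjoint -> kept
    print(nms(dets))  # [0, 2]
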
python-package/insightface/model_zoo/face_genderage.py (removed)
@@ -1,73 +0,0 @@
from __future__ import division
import mxnet as mx
import numpy as np
import cv2

__all__ = ['FaceGenderage', 'genderage_v1', 'get_genderage']


class FaceGenderage:
    def __init__(self, name, download, param_file):
        self.name = name
        self.download = download
        self.param_file = param_file
        self.image_size = (112, 112)
        if download:
            assert param_file

    def prepare(self, ctx_id):
        if self.param_file:
            pos = self.param_file.rfind('-')
            prefix = self.param_file[0:pos]
            pos2 = self.param_file.rfind('.')
            epoch = int(self.param_file[pos + 1:pos2])
            sym, arg_params, aux_params = mx.model.load_checkpoint(
                prefix, epoch)
            all_layers = sym.get_internals()
            sym = all_layers['fc1_output']
            if ctx_id >= 0:
                ctx = mx.gpu(ctx_id)
            else:
                ctx = mx.cpu()
            model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
            data_shape = (1, 3) + self.image_size
            model.bind(data_shapes=[('data', data_shape)])
            model.set_params(arg_params, aux_params)
            #warmup
            data = mx.nd.zeros(shape=data_shape)
            db = mx.io.DataBatch(data=(data, ))
            model.forward(db, is_train=False)
            embedding = model.get_outputs()[0].asnumpy()
            self.model = model
        else:
            pass

    def get(self, img):
        assert self.param_file and self.model
        assert img.shape[2] == 3 and img.shape[0:2] == self.image_size
        data = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        data = np.transpose(data, (2, 0, 1))
        data = np.expand_dims(data, axis=0)
        data = mx.nd.array(data)
        db = mx.io.DataBatch(data=(data, ))
        self.model.forward(db, is_train=False)
        ret = self.model.get_outputs()[0].asnumpy()
        g = ret[:, 0:2].flatten()
        gender = np.argmax(g)
        a = ret[:, 2:202].reshape((100, 2))
        a = np.argmax(a, axis=1)
        age = int(sum(a))
        return gender, age


def get_genderage(name, download=True, root='~/.insightface/models', **kwargs):
    if not download:
        return FaceGenderage(name, False, None)
    else:
        from .model_store import get_model_file
        _file = get_model_file("genderage_%s" % name, root=root)
        return FaceGenderage(name, True, _file)


def genderage_v1(**kwargs):
    return get_genderage("v1", download=True, **kwargs)
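For reference, the removed get() decodes a 202-dim fc1 output as 2 gender logits plus 100 two-way age bins; the predicted age is the count of bins whose second logit wins. A small numpy re-enactment on synthetic values (not real model output, and no claim about which gender index means what):

    import numpy as np

    ret = np.zeros((1, 202), dtype=np.float32)      # synthetic fc1 output
    ret[0, 1] = 1.0                                 # gender logits -> argmax gives class 1
    ret[0, 2:2 + 2 * 25] = np.tile([0.0, 1.0], 25)  # first 25 age bins 'fire'

    g = ret[:, 0:2].flatten()
    gender = np.argmax(g)
    a = ret[:, 2:202].reshape((100, 2))
    age = int(np.sum(np.argmax(a, axis=1)))
    print(gender, age)  # 1 25
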
python-package/insightface/model_zoo/face_recognition.py (removed)
@@ -1,86 +0,0 @@
from __future__ import division
import mxnet as mx
import numpy as np
import cv2

__all__ = [
    'FaceRecognition', 'arcface_r100_v1', 'arcface_outofreach_v1',
    'arcface_mfn_v1', 'get_arcface'
]


class FaceRecognition:
    def __init__(self, name, download, param_file):
        self.name = name
        self.download = download
        self.param_file = param_file
        self.image_size = (112, 112)
        if download:
            assert param_file

    def prepare(self, ctx_id):
        if self.param_file:
            pos = self.param_file.rfind('-')
            prefix = self.param_file[0:pos]
            pos2 = self.param_file.rfind('.')
            epoch = int(self.param_file[pos + 1:pos2])
            sym, arg_params, aux_params = mx.model.load_checkpoint(
                prefix, epoch)
            all_layers = sym.get_internals()
            sym = all_layers['fc1_output']
            if ctx_id >= 0:
                ctx = mx.gpu(ctx_id)
            else:
                ctx = mx.cpu()
            model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
            data_shape = (1, 3) + self.image_size
            model.bind(data_shapes=[('data', data_shape)])
            model.set_params(arg_params, aux_params)
            #warmup
            data = mx.nd.zeros(shape=data_shape)
            db = mx.io.DataBatch(data=(data, ))
            model.forward(db, is_train=False)
            embedding = model.get_outputs()[0].asnumpy()
            self.model = model
        else:
            pass

    def get_embedding(self, img):
        assert self.param_file and self.model
        assert img.shape[2] == 3 and img.shape[0:2] == self.image_size
        data = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        data = np.transpose(data, (2, 0, 1))
        data = np.expand_dims(data, axis=0)
        data = mx.nd.array(data)
        db = mx.io.DataBatch(data=(data, ))
        self.model.forward(db, is_train=False)
        embedding = self.model.get_outputs()[0].asnumpy()
        return embedding

    def compute_sim(self, img1, img2):
        emb1 = self.get_embedding(img1).flatten()
        emb2 = self.get_embedding(img2).flatten()
        from numpy.linalg import norm
        sim = np.dot(emb1, emb2) / (norm(emb1) * norm(emb2))
        return sim


def get_arcface(name, download=True, root='~/.insightface/models', **kwargs):
    if not download:
        return FaceRecognition(name, False, None)
    else:
        from .model_store import get_model_file
        _file = get_model_file("arcface_%s" % name, root=root)
        return FaceRecognition(name, True, _file)


def arcface_r100_v1(**kwargs):
    return get_arcface("r100_v1", download=True, **kwargs)


def arcface_mfn_v1(**kwargs):
    return get_arcface("mfn_v1", download=True, **kwargs)


def arcface_outofreach_v1(**kwargs):
    return get_arcface("outofreach_v1", download=False, **kwargs)
python-package/insightface/model_zoo/model_store.py
@@ -22,7 +22,7 @@ _model_sha1 = {
    ]
}

base_repo_url = 'http://insightface.ai/files/'
base_repo_url = 'https://insightface.ai/files/'
_url_format = '{repo_url}models/{file_name}.zip'

python-package/insightface/model_zoo/model_zoo.py
@@ -1,56 +1,59 @@
# pylint: disable=wildcard-import, unused-wildcard-import
"""
This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_zoo.py
"""
from .face_recognition import *
from .face_detection import *
from .face_genderage import *
#from .face_alignment import *
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author       : Jia Guo
# @Time         : 2021-05-04
# @Function     :

__all__ = ['get_model', 'get_model_list']
import os
import os.path as osp
import glob
import onnxruntime
from .arcface_onnx import *
from .scrfd import *

_models = {
    'arcface_r100_v1': arcface_r100_v1,
    #'arcface_mfn_v1': arcface_mfn_v1,
    #'arcface_outofreach_v1': arcface_outofreach_v1,
    'retinaface_r50_v1': retinaface_r50_v1,
    'retinaface_mnet025_v1': retinaface_mnet025_v1,
    'retinaface_mnet025_v2': retinaface_mnet025_v2,
    'genderage_v1': genderage_v1,
}
#__all__ = ['get_model', 'get_model_list', 'get_arcface_onnx', 'get_scrfd']
__all__ = ['get_model']


class ModelRouter:
    def __init__(self, onnx_file):
        self.onnx_file = onnx_file

    def get_model(self):
        session = onnxruntime.InferenceSession(self.onnx_file, None)
        input_cfg = session.get_inputs()[0]
        input_shape = input_cfg.shape
        outputs = session.get_outputs()
        #print(input_shape)
        if len(outputs) >= 5:
            return SCRFD(model_file=self.onnx_file, session=session)
        elif input_shape[2] == 112 and input_shape[3] == 112:
            return ArcFaceONNX(model_file=self.onnx_file, session=session)
        else:
            raise RuntimeError('error on model routing')


def find_onnx_file(dir_path):
    if not os.path.exists(dir_path):
        return None
    paths = glob.glob("%s/*.onnx" % dir_path)
    if len(paths) == 0:
        return None
    paths = sorted(paths)
    return paths[-1]


def get_model(name, **kwargs):
    """Returns a pre-defined model by name
    root = kwargs.get('root', '~/.insightface/models')
    root = os.path.expanduser(root)
    if not name.endswith('.onnx'):
        model_dir = os.path.join(root, name)
        model_file = find_onnx_file(model_dir)
        if model_file is None:
            return None
    else:
        model_file = name
    assert osp.isfile(model_file), 'model should be file'
    router = ModelRouter(model_file)
    model = router.get_model()
    #print('get-model for ', name,' : ', model.taskname)
    return model

    Parameters
    ----------
    name : str
        Name of the model.
    root : str, default '~/.insightface/models'
        Location for keeping the model parameters.

    Returns
    -------
    Model
        The model.
    """
    name = name.lower()
    if name not in _models:
        err_str = '"%s" is not among the following model list:\n\t' % (name)
        err_str += '%s' % ('\n\t'.join(sorted(_models.keys())))
        raise ValueError(err_str)
    net = _models[name](**kwargs)
    return net


def get_model_list():
    """Get the entire list of model names in model_zoo.

    Returns
    -------
    list of str
        Entire list of model names in model_zoo.

    """
    return sorted(_models.keys())

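The routing heuristic above replaces the old name-keyed registry: any .onnx file with five or more outputs is treated as an SCRFD detector, and a 112x112-input model as an ArcFace recognizer. A sketch (the filenames are placeholders for local model files, not shipped assets):

    from insightface.model_zoo import get_model

    det = get_model('scrfd_det.onnx')    # >= 5 outputs -> routed to SCRFD
    rec = get_model('arcface_r100.onnx') # 112x112 input -> routed to ArcFaceONNX
    print(det.taskname, rec.taskname)    # detection recognition
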
python-package/insightface/model_zoo/scrfd.py (new file, 353 lines)
@@ -0,0 +1,353 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author       : Jia Guo
# @Time         : 2021-05-04
# @Function     :

from __future__ import division
import datetime
import numpy as np
import onnx
import onnxruntime
import os
import os.path as osp
import cv2
import sys


def softmax(z):
    assert len(z.shape) == 2
    s = np.max(z, axis=1)
    s = s[:, np.newaxis]  # necessary step to do broadcasting
    e_x = np.exp(z - s)
    div = np.sum(e_x, axis=1)
    div = div[:, np.newaxis]  # ditto
    return e_x / div


def distance2bbox(points, distance, max_shape=None):
    """Decode distance prediction to bounding box.

    Args:
        points (Tensor): Shape (n, 2), [x, y].
        distance (Tensor): Distance from the given point to 4
            boundaries (left, top, right, bottom).
        max_shape (tuple): Shape of the image.

    Returns:
        Tensor: Decoded bboxes.
    """
    x1 = points[:, 0] - distance[:, 0]
    y1 = points[:, 1] - distance[:, 1]
    x2 = points[:, 0] + distance[:, 2]
    y2 = points[:, 1] + distance[:, 3]
    if max_shape is not None:
        # note: .clamp is a torch-style method; SCRFD never passes max_shape,
        # so this branch is not exercised with numpy inputs
        x1 = x1.clamp(min=0, max=max_shape[1])
        y1 = y1.clamp(min=0, max=max_shape[0])
        x2 = x2.clamp(min=0, max=max_shape[1])
        y2 = y2.clamp(min=0, max=max_shape[0])
    return np.stack([x1, y1, x2, y2], axis=-1)


def distance2kps(points, distance, max_shape=None):
    """Decode distance prediction to keypoints.

    Args:
        points (Tensor): Shape (n, 2), [x, y].
        distance (Tensor): Offsets from the given points to the keypoints,
            interleaved as (dx0, dy0, dx1, dy1, ...).
        max_shape (tuple): Shape of the image.

    Returns:
        Tensor: Decoded keypoints.
    """
    preds = []
    for i in range(0, distance.shape[1], 2):
        px = points[:, i % 2] + distance[:, i]
        py = points[:, i % 2 + 1] + distance[:, i + 1]
        if max_shape is not None:
            px = px.clamp(min=0, max=max_shape[1])
            py = py.clamp(min=0, max=max_shape[0])
        preds.append(px)
        preds.append(py)
    return np.stack(preds, axis=-1)

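A quick numeric check of the box decoding above, importing the helper from this new module:

    import numpy as np
    from insightface.model_zoo.scrfd import distance2bbox

    points = np.array([[32.0, 32.0]])            # one anchor center (x, y)
    distance = np.array([[4.0, 4.0, 8.0, 8.0]])  # left, top, right, bottom
    print(distance2bbox(points, distance))       # [[28. 28. 40. 40.]]
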
class SCRFD:
    def __init__(self, model_file=None, session=None):
        import onnxruntime
        self.model_file = model_file
        self.session = session
        self.taskname = 'detection'
        if self.session is None:
            assert self.model_file is not None
            assert osp.exists(self.model_file)
            self.session = onnxruntime.InferenceSession(self.model_file, None)
        self.center_cache = {}
        self.nms_threshold = 0.4
        self._init_vars()

    def _init_vars(self):
        input_cfg = self.session.get_inputs()[0]
        input_shape = input_cfg.shape
        #print(input_shape)
        if isinstance(input_shape[2], str):
            self.input_size = None
        else:
            self.input_size = tuple(input_shape[2:4][::-1])
        #print('image_size:', self.image_size)
        input_name = input_cfg.name
        outputs = self.session.get_outputs()
        output_names = []
        for o in outputs:
            output_names.append(o.name)
        self.input_name = input_name
        self.output_names = output_names
        #print(self.output_names)
        #assert len(outputs)==10 or len(outputs)==15
        self.use_kps = False
        self._anchor_ratio = 1.0
        self._num_anchors = 1
        if len(outputs) == 6:
            self.fmc = 3
            self._feat_stride_fpn = [8, 16, 32]
            self._num_anchors = 2
        elif len(outputs) == 9:
            self.fmc = 3
            self._feat_stride_fpn = [8, 16, 32]
            self._num_anchors = 2
            self.use_kps = True
        elif len(outputs) == 10:
            self.fmc = 5
            self._feat_stride_fpn = [8, 16, 32, 64, 128]
            self._num_anchors = 1
        elif len(outputs) == 15:
            self.fmc = 5
            self._feat_stride_fpn = [8, 16, 32, 64, 128]
            self._num_anchors = 1
            self.use_kps = True

    def prepare(self, ctx_id, **kwargs):
        if ctx_id < 0:
            self.session.set_providers(['CPUExecutionProvider'])
        nms_threshold = kwargs.get('nms_threshold', None)
        if nms_threshold is not None:
            self.nms_threshold = nms_threshold
        input_size = kwargs.get('input_size', None)
        if input_size is not None:
            if self.input_size is not None:
                print('warning: det_size is already set in scrfd model, ignore')
            else:
                self.input_size = input_size
        #for keyword in ['_selfgen', '_shape', '.']:
        #    pos = self.model_file.rfind(keyword)
        #    if pos>=0:
        #        break
        #model_prefix = self.model_file[0:pos]
        #new_model_file = model_prefix+"_selfgen_shape%dx%d.onnx"%(input_size[1], input_size[0])
        #if not osp.exists(new_model_file):
        #    model = onnx.load(self.model_file)
        #    from onnxsim import simplify
        #    input = model.graph.input[0]
        #    #model.graph.input[0].type.tensor_type.elem_type = 0
        #    #print(input.type.tensor_type.elem_type)
        #    #print(input.type.tensor_type.shape.dim[0].dim_param.__class__)
        #    #print(input.type.tensor_type.shape.dim[0])
        #    #input.type.tensor_type.shape.dim[2].dim_param = input_size[1]
        #    #input.type.tensor_type.shape.dim[3].dim_param = input_size[0]
        #    input.type.tensor_type.shape.dim[2].dim_value = input_size[1]
        #    input.type.tensor_type.shape.dim[3].dim_value = input_size[0]
        #    model, check = simplify(model)
        #    assert check, "Simplified ONNX model could not be validated"
        #    onnx.save(model, new_model_file)
        #    print('saved new onnx scrfd model:', new_model_file)
        #self.model_file = new_model_file
        #self.session = onnxruntime.InferenceSession(self.model_file, None)
        #self._init_vars()

    def forward(self, img, threshold):
        scores_list = []
        bboxes_list = []
        kpss_list = []
        input_size = tuple(img.shape[0:2][::-1])
        blob = cv2.dnn.blobFromImage(img, 1.0/128, input_size, (127.5, 127.5, 127.5), swapRB=True)
        net_outs = self.session.run(self.output_names, {self.input_name: blob})

        input_height = blob.shape[2]
        input_width = blob.shape[3]
        fmc = self.fmc
        for idx, stride in enumerate(self._feat_stride_fpn):
            scores = net_outs[idx]
            bbox_preds = net_outs[idx + fmc]
            bbox_preds = bbox_preds * stride
            if self.use_kps:
                kps_preds = net_outs[idx + fmc * 2] * stride
            height = input_height // stride
            width = input_width // stride
            K = height * width
            key = (height, width, stride)
            if key in self.center_cache:
                anchor_centers = self.center_cache[key]
            else:
                #solution-1, c style:
                #anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 )
                #for i in range(height):
                #    anchor_centers[i, :, 1] = i
                #for i in range(width):
                #    anchor_centers[:, i, 0] = i

                #solution-2:
                #ax = np.arange(width, dtype=np.float32)
                #ay = np.arange(height, dtype=np.float32)
                #xv, yv = np.meshgrid(np.arange(width), np.arange(height))
                #anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32)

                #solution-3:
                anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
                #print(anchor_centers.shape)

                anchor_centers = (anchor_centers * stride).reshape((-1, 2))
                if self._num_anchors > 1:
                    anchor_centers = np.stack([anchor_centers] * self._num_anchors, axis=1).reshape((-1, 2))
                if len(self.center_cache) < 100:
                    self.center_cache[key] = anchor_centers

            pos_inds = np.where(scores >= threshold)[0]
            bboxes = distance2bbox(anchor_centers, bbox_preds)
            pos_scores = scores[pos_inds]
            pos_bboxes = bboxes[pos_inds]
            scores_list.append(pos_scores)
            bboxes_list.append(pos_bboxes)
            if self.use_kps:
                kpss = distance2kps(anchor_centers, kps_preds)
                #kpss = kps_preds
                kpss = kpss.reshape((kpss.shape[0], -1, 2))
                pos_kpss = kpss[pos_inds]
                kpss_list.append(pos_kpss)
        return scores_list, bboxes_list, kpss_list

    def detect(self, img, threshold=0.5, input_size=None, max_num=0, metric='default'):
        assert input_size is not None or self.input_size is not None
        input_size = self.input_size if input_size is None else input_size

        im_ratio = float(img.shape[0]) / img.shape[1]
        model_ratio = float(input_size[1]) / input_size[0]
        if im_ratio > model_ratio:
            new_height = input_size[1]
            new_width = int(new_height / im_ratio)
        else:
            new_width = input_size[0]
            new_height = int(new_width * im_ratio)
        det_scale = float(new_height) / img.shape[0]
        resized_img = cv2.resize(img, (new_width, new_height))
        det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8)
        det_img[:new_height, :new_width, :] = resized_img

        scores_list, bboxes_list, kpss_list = self.forward(det_img, threshold)

        scores = np.vstack(scores_list)
        scores_ravel = scores.ravel()
        order = scores_ravel.argsort()[::-1]
        bboxes = np.vstack(bboxes_list) / det_scale
        if self.use_kps:
            kpss = np.vstack(kpss_list) / det_scale
        pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
        pre_det = pre_det[order, :]
        keep = self.nms(pre_det)
        det = pre_det[keep, :]
        if self.use_kps:
            kpss = kpss[order, :, :]
            kpss = kpss[keep, :, :]
        else:
            kpss = None
        if max_num > 0 and det.shape[0] > max_num:
            area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
            img_center = img.shape[0] // 2, img.shape[1] // 2
            offsets = np.vstack([
                (det[:, 0] + det[:, 2]) / 2 - img_center[1],
                (det[:, 1] + det[:, 3]) / 2 - img_center[0]
            ])
            offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
            if metric == 'max':
                values = area
            else:
                values = area - offset_dist_squared * 2.0  # some extra weight on the centering
            bindex = np.argsort(values)[::-1]  # some extra weight on the centering
            bindex = bindex[0:max_num]
            det = det[bindex, :]
            if kpss is not None:
                kpss = kpss[bindex, :]
        return det, kpss

    def nms(self, dets):
        thresh = self.nms_threshold
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = dets[:, 2]
        y2 = dets[:, 3]
        scores = dets[:, 4]

        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)

            inds = np.where(ovr <= thresh)[0]
            order = order[inds + 1]

        return keep


def get_scrfd(name, download=False, root='~/.insightface/models', **kwargs):
    if not download:
        assert os.path.exists(name)
        return SCRFD(name)
    else:
        from .model_store import get_model_file
        _file = get_model_file("scrfd_%s" % name, root=root)
        return SCRFD(_file)


def scrfd_2p5gkps(**kwargs):
    return get_scrfd("2p5gkps", download=True, **kwargs)


if __name__ == '__main__':
    import glob
    detector = SCRFD(model_file='./det.onnx')
    detector.prepare(-1)
    img_paths = ['tests/data/t1.jpg']
    for img_path in img_paths:
        img = cv2.imread(img_path)

        for _ in range(1):
            ta = datetime.datetime.now()
            #bboxes, kpss = detector.detect(img, 0.5, input_size = (640, 640))
            bboxes, kpss = detector.detect(img, 0.5)
            tb = datetime.datetime.now()
            print('all cost:', (tb - ta).total_seconds() * 1000)
        print(img_path, bboxes.shape)
        if kpss is not None:
            print(kpss.shape)
        for i in range(bboxes.shape[0]):
            bbox = bboxes[i]
            x1, y1, x2, y2, score = bbox.astype(np.int)
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
            if kpss is not None:
                kps = kpss[i]
                for kp in kps:
                    kp = kp.astype(np.int)
                    cv2.circle(img, tuple(kp), 1, (0, 0, 255), 2)
        filename = img_path.split('/')[-1]
        print('output:', filename)
        cv2.imwrite('./outputs/%s' % filename, img)
