code of pip insightface==0.2

This commit is contained in:
Jia Guo
2021-05-04 16:17:50 +08:00
parent 5895a9fb3c
commit 6fce7ddaf0
12 changed files with 615 additions and 761 deletions

View File

@@ -3,25 +3,15 @@
"""InsightFace: A Face Analysis Toolkit."""
from __future__ import absolute_import
# mxnet version check
#mx_version = '1.4.0'
try:
import mxnet as mx
#from distutils.version import LooseVersion
#if LooseVersion(mx.__version__) < LooseVersion(mx_version):
# msg = (
# "Legacy mxnet-mkl=={} detected, some new modules may not work properly. "
# "mxnet-mkl>={} is required. You can use pip to upgrade mxnet "
# "`pip install mxnet-mkl --pre --upgrade` "
# "or `pip install mxnet-cu90mkl --pre --upgrade`").format(mx.__version__, mx_version)
# raise ImportError(msg)
#import mxnet as mx
import onnxruntime
except ImportError:
raise ImportError(
"Unable to import dependency mxnet. "
"A quick tip is to install via `pip install mxnet-mkl/mxnet-cu90mkl --pre`. "
"Unable to import dependency onnxruntime. "
)
__version__ = '0.1.5'
__version__ = '0.2.0'
from . import model_zoo
from . import utils

View File

@@ -1,85 +1,100 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
from __future__ import division
import collections
import mxnet as mx
import numpy as np
import glob
import os
import os.path as osp
from numpy.linalg import norm
import mxnet.ndarray as nd
from ..model_zoo import model_zoo
from ..utils import face_align
__all__ = ['FaceAnalysis', 'Face']
Face = collections.namedtuple('Face', [
'bbox', 'landmark', 'det_score', 'embedding', 'gender', 'age',
'embedding_norm', 'normed_embedding'
'bbox', 'kps', 'det_score', 'embedding', 'gender', 'age',
'embedding_norm', 'normed_embedding',
'landmark'
])
Face.__new__.__defaults__ = (None, ) * len(Face._fields)
class FaceAnalysis:
def __init__(self,
det_name='retinaface_r50_v1',
rec_name='arcface_r100_v1',
ga_name='genderage_v1'):
assert det_name is not None
self.det_model = model_zoo.get_model(det_name)
if rec_name is not None:
self.rec_model = model_zoo.get_model(rec_name)
else:
self.rec_model = None
if ga_name is not None:
self.ga_model = model_zoo.get_model(ga_name)
else:
self.ga_model = None
def __init__(self, name, root='~/.insightface/models'):
self.models = {}
root = os.path.expanduser(root)
onnx_files = glob.glob(osp.join(root, name, '*.onnx'))
onnx_files = sorted(onnx_files)
for onnx_file in onnx_files:
if onnx_file.find('_selfgen_')>0:
#print('ignore:', onnx_file)
continue
model = model_zoo.get_model(onnx_file)
if model.taskname not in self.models:
print('find model:', onnx_file, model.taskname)
self.models[model.taskname] = model
else:
print('duplicated model task type, ignore:', onnx_file, model.taskname)
del model
assert 'detection' in self.models
self.det_model = self.models['detection']
def prepare(self, ctx_id, nms=0.4):
self.det_model.prepare(ctx_id, nms)
if self.rec_model is not None:
self.rec_model.prepare(ctx_id)
if self.ga_model is not None:
self.ga_model.prepare(ctx_id)
def get(self, img, det_thresh=0.8, det_scale=1.0, max_num=0):
bboxes, landmarks = self.det_model.detect(img,
threshold=det_thresh,
scale=det_scale)
def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)):
self.det_thresh = det_thresh
assert det_size is not None
print('set det-size:', det_size)
self.det_size = det_size
for taskname, model in self.models.items():
if taskname=='detection':
model.prepare(ctx_id, input_size=det_size)
else:
model.prepare(ctx_id)
def get(self, img, max_num=0):
bboxes, kpss = self.det_model.detect(img,
threshold=self.det_thresh,
max_num=max_num,
metric='default')
if bboxes.shape[0] == 0:
return []
if max_num > 0 and bboxes.shape[0] > max_num:
area = (bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] -
bboxes[:, 1])
img_center = img.shape[0] // 2, img.shape[1] // 2
offsets = np.vstack([
(bboxes[:, 0] + bboxes[:, 2]) / 2 - img_center[1],
(bboxes[:, 1] + bboxes[:, 3]) / 2 - img_center[0]
])
offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
values = area - offset_dist_squared * 2.0 # some extra weight on the centering
bindex = np.argsort(
values)[::-1] # some extra weight on the centering
bindex = bindex[0:max_num]
bboxes = bboxes[bindex, :]
landmarks = landmarks[bindex, :]
ret = []
for i in range(bboxes.shape[0]):
bbox = bboxes[i, 0:4]
det_score = bboxes[i, 4]
landmark = landmarks[i]
_img = face_align.norm_crop(img, landmark=landmark)
kps = None
if kpss is not None:
kps = kpss[i]
embedding = None
embedding_norm = None
normed_embedding = None
embedding_norm = None
gender = None
age = None
if self.rec_model is not None:
embedding = self.rec_model.get_embedding(_img).flatten()
if 'recognition' in self.models:
assert kps is not None
rec_model = self.models['recognition']
aimg = face_align.norm_crop(img, landmark=kps)
embedding = None
embedding_norm = None
normed_embedding = None
gender = None
age = None
embedding = rec_model.get_feat(aimg).flatten()
embedding_norm = norm(embedding)
normed_embedding = embedding / embedding_norm
if self.ga_model is not None:
gender, age = self.ga_model.get(_img)
if 'genderage' in self.models:
assert aimg is not None
ga_model = self.models['genderage']
gender, age = ga_model.get(_img)
face = Face(bbox=bbox,
landmark=landmark,
kps=kps,
det_score=det_score,
embedding=embedding,
gender=gender,
@@ -88,3 +103,4 @@ class FaceAnalysis:
embedding_norm=embedding_norm)
ret.append(face)
return ret

View File

@@ -1 +1,3 @@
from .model_zoo import get_model, get_model_list
from .model_zoo import get_model
from .arcface_onnx import ArcFaceONNX
from .scrfd import SCRFD

View File

@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
from __future__ import division
import numpy as np
import cv2
import onnx
import onnxruntime
from ..utils import face_align
__all__ = [
'ArcFaceONNX',
]
class ArcFaceONNX:
    """ArcFace recognition model executed through ONNX Runtime.

    The constructor inspects the first graph nodes of the ONNX file to decide
    which input normalisation to apply, then records the input/output names
    and the expected input size of the inference session.
    """

    def __init__(self, model_file=None, session=None):
        """Load model metadata and (if needed) create the inference session.

        Args:
            model_file: Path of the ONNX model; required.
            session: Optional, already created ``onnxruntime.InferenceSession``.
                When ``None`` a session is created from ``model_file``.
        """
        import onnxruntime
        assert model_file is not None
        self.model_file = model_file
        self.session = session
        self.taskname = 'recognition'
        find_sub = False
        find_mul = False
        model = onnx.load(self.model_file)
        graph = model.graph
        # Only the first 8 nodes are inspected for a built-in Sub/Mul pair.
        for node in graph.node[:8]:
            if node.name.startswith(('Sub', '_minus')):
                find_sub = True
            if node.name.startswith(('Mul', '_mul')):
                find_mul = True
        if find_sub and find_mul:
            # mxnet arcface model: normalisation is already part of the graph
            input_mean = 0.0
            input_std = 1.0
        else:
            input_mean = 127.5
            input_std = 127.5
        self.input_mean = input_mean
        self.input_std = input_std
        print('input mean and std:', self.input_mean, self.input_std)
        if self.session is None:
            self.session = onnxruntime.InferenceSession(self.model_file, None)
        input_cfg = self.session.get_inputs()[0]
        input_shape = input_cfg.shape
        # Dimensions 2 and 3 reversed (assumes an NCHW input, i.e. (width, height)).
        self.input_size = tuple(input_shape[2:4][::-1])
        # NOTE: the session must NOT be reassigned from the ``session``
        # argument here, otherwise a session created above would be lost.
        self.input_name = input_cfg.name
        self.output_names = [out.name for out in self.session.get_outputs()]
        assert len(self.output_names) == 1

    def prepare(self, ctx_id, **kwargs):
        """Restrict the session to the CPU provider when ``ctx_id`` is negative."""
        if ctx_id < 0:
            self.session.set_providers(['CPUExecutionProvider'])

    def get_feat(self, img):
        """Run the model on one aligned face image and return the raw output.

        Args:
            img: 3-channel image whose (width, height) equals ``self.input_size``.

        Returns:
            The first (and only) output of the session.
        """
        assert img.shape[2] == 3
        input_size = tuple(img.shape[0:2][::-1])
        assert input_size == self.input_size
        mean = (self.input_mean, self.input_mean, self.input_mean)
        blob = cv2.dnn.blobFromImage(img, 1.0 / self.input_std, input_size,
                                     mean, swapRB=True)
        net_outs = self.session.run(self.output_names, {self.input_name: blob})
        return net_outs[0]

    def compute_sim(self, feat1, feat2):
        """Return the cosine similarity of two feature vectors.

        Both inputs are flattened first, so any array shape with the same
        number of elements is accepted.
        """
        from numpy.linalg import norm
        feat1 = feat1.ravel()
        feat2 = feat2.ravel()
        return np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))

View File

@@ -1,466 +0,0 @@
from __future__ import division
import mxnet as mx
import numpy as np
import mxnet.ndarray as nd
import cv2
__all__ = [
'FaceDetector', 'retinaface_r50_v1', 'retinaface_mnet025_v1',
'retinaface_mnet025_v2', 'get_retinaface'
]
def _whctrs(anchor):
    """Describe an anchor window by its size and centre.

    ``anchor`` is indexed as ``[x1, y1, x2, y2]``; both corners count as part
    of the window, hence the ``+ 1`` in the size.

    Returns:
        Tuple ``(width, height, x_center, y_center)``.
    """
    x1, y1, x2, y2 = anchor[0], anchor[1], anchor[2], anchor[3]
    width = x2 - x1 + 1
    height = y2 - y1 + 1
    center_x = x1 + 0.5 * (width - 1)
    center_y = y1 + 0.5 * (height - 1)
    return width, height, center_x, center_y
def _mkanchors(ws, hs, x_ctr, y_ctr):
    """Build one ``[x1, y1, x2, y2]`` row per (width, height) pair.

    Args:
        ws: 1-D array of widths.
        hs: 1-D array of heights, same length as ``ws``.
        x_ctr, y_ctr: Centre shared by every produced window.

    Returns:
        Array with one row per entry of ``ws``/``hs`` and four columns.
    """
    half_w = 0.5 * (ws[:, np.newaxis] - 1)
    half_h = 0.5 * (hs[:, np.newaxis] - 1)
    left, right = x_ctr - half_w, x_ctr + half_w
    top, bottom = y_ctr - half_h, y_ctr + half_h
    return np.hstack((left, top, right, bottom))
def _ratio_enum(anchor, ratios):
    """Produce one anchor per aspect ratio, centred on ``anchor``.

    For every ratio the width is ``round(sqrt(area / ratio))`` and the height
    is ``round(width * ratio)``, where ``area`` is the size of ``anchor``.
    """
    w, h, x_ctr, y_ctr = _whctrs(anchor)
    area = w * h
    widths = np.round(np.sqrt(area / ratios))
    heights = np.round(widths * ratios)
    return _mkanchors(widths, heights, x_ctr, y_ctr)
def _scale_enum(anchor, scales):
    """Produce one anchor per scale factor, centred on ``anchor``.

    Width and height of ``anchor`` are each multiplied by every entry of
    ``scales``; the centre is kept.
    """
    w, h, x_ctr, y_ctr = _whctrs(anchor)
    return _mkanchors(w * scales, h * scales, x_ctr, y_ctr)
def anchors_plane(height, width, stride, base_anchors):
    """Replicate the base anchors over every cell of a feature plane.

    Parameters
    ----------
    height: height of plane
    width: width of plane
    stride: stride of the original image
    base_anchors: (A, 4) a base set of anchors

    Returns
    -------
    all_anchors: (height, width, A, 4) float32 ndarray of anchors spreading
        over the plane; cell ``(ih, iw)`` holds the base anchors shifted by
        ``(iw * stride, ih * stride)`` in x and y.
    """
    base_anchors = np.asarray(base_anchors)
    shift_x, shift_y = np.meshgrid(np.arange(width) * stride,
                                   np.arange(height) * stride)
    # Per-cell (x, y, x, y) offset, shaped (height, width, 1, 4) so that it
    # broadcasts against the (A, 4) base anchors.
    shifts = np.stack([shift_x, shift_y, shift_x, shift_y],
                      axis=-1)[:, :, np.newaxis, :]
    if np.issubdtype(base_anchors.dtype, np.floating):
        # Add in the anchors' own precision, as the element-wise loop did.
        shifts = shifts.astype(base_anchors.dtype)
    all_anchors = base_anchors[np.newaxis, np.newaxis, :, :] + shifts
    return all_anchors.astype(np.float32)
def generate_anchors(base_size=16,
                     ratios=[0.5, 1, 2],
                     scales=2**np.arange(3, 6),
                     stride=16):
    """Enumerate anchors for every (ratio, scale) combination.

    The reference window is ``(0, 0, base_size - 1, base_size - 1)``. It is
    first expanded per aspect ratio, then each of those windows is expanded
    per scale. ``stride`` is accepted but not used by the computation.

    Returns:
        Array with one ``[x1, y1, x2, y2]`` row per generated anchor.
    """
    reference = np.array([1, 1, base_size, base_size]) - 1
    per_ratio = _ratio_enum(reference, ratios)
    scaled = []
    for row in per_ratio:
        scaled.append(_scale_enum(row, scales))
    return np.vstack(scaled)
def generate_anchors_fpn(cfg):
    """Generate the base anchors of every pyramid level described by ``cfg``.

    Args:
        cfg: Mapping whose keys are stride values as strings and whose values
            are dicts with ``'BASE_SIZE'``, ``'RATIOS'`` and ``'SCALES'``.

    Returns:
        List of anchor arrays, ordered from the largest stride to the smallest.
    """
    strides = sorted((int(key) for key in cfg), reverse=True)
    anchors = []
    for stride in strides:
        level = cfg[str(stride)]
        base_size = level['BASE_SIZE']
        ratios = np.array(level['RATIOS'])
        scales = np.array(level['SCALES'])
        anchors.append(generate_anchors(base_size, ratios, scales, stride))
    return anchors
def clip_pad(tensor, pad_shape):
    """Crop the two trailing axes of a 4-D array down to ``pad_shape``.

    :param tensor: [n, c, H, W]
    :param pad_shape: [h, w]
    :return: ``tensor`` itself when neither ``h < H`` nor ``w < W``; otherwise
        a copy of ``tensor[:, :, :h, :w]``.
    """
    full_h, full_w = tensor.shape[2:]
    target_h, target_w = pad_shape
    if target_h >= full_h and target_w >= full_w:
        return tensor
    return tensor[:, :, :target_h, :target_w].copy()
def bbox_pred(boxes, box_deltas):
    """
    Transform the set of class-agnostic boxes into class-specific boxes
    by applying the predicted offsets (box_deltas)
    :param boxes: !important [N 4]
    :param box_deltas: [N, 4 * num_classes]
    :return: [N 4 * num_classes]

    Only the first four delta columns (dx, dy, dw, dh) are decoded; any
    further columns are copied to the output unchanged.
    """
    if boxes.shape[0] == 0:
        return np.zeros((0, box_deltas.shape[1]))
    # ``np.float`` was only an alias of the builtin float and no longer
    # exists in current NumPy; float64 is the same dtype.
    boxes = boxes.astype(np.float64, copy=False)
    widths = boxes[:, 2] - boxes[:, 0] + 1.0
    heights = boxes[:, 3] - boxes[:, 1] + 1.0
    ctr_x = boxes[:, 0] + 0.5 * (widths - 1.0)
    ctr_y = boxes[:, 1] + 0.5 * (heights - 1.0)
    dx = box_deltas[:, 0:1]
    dy = box_deltas[:, 1:2]
    dw = box_deltas[:, 2:3]
    dh = box_deltas[:, 3:4]
    pred_ctr_x = dx * widths[:, np.newaxis] + ctr_x[:, np.newaxis]
    pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis]
    pred_w = np.exp(dw) * widths[:, np.newaxis]
    pred_h = np.exp(dh) * heights[:, np.newaxis]
    pred_boxes = np.zeros(box_deltas.shape)
    # x1
    pred_boxes[:, 0:1] = pred_ctr_x - 0.5 * (pred_w - 1.0)
    # y1
    pred_boxes[:, 1:2] = pred_ctr_y - 0.5 * (pred_h - 1.0)
    # x2
    pred_boxes[:, 2:3] = pred_ctr_x + 0.5 * (pred_w - 1.0)
    # y2
    pred_boxes[:, 3:4] = pred_ctr_y + 0.5 * (pred_h - 1.0)
    if box_deltas.shape[1] > 4:
        pred_boxes[:, 4:] = box_deltas[:, 4:]
    return pred_boxes
def landmark_pred(boxes, landmark_deltas):
    """Decode landmark offsets relative to their anchor boxes.

    Args:
        boxes: ``(N, 4)`` array of ``[x1, y1, x2, y2]`` boxes.
        landmark_deltas: ``(N, P, 2)`` array of per-point (x, y) offsets,
            expressed as fractions of the box width/height.

    Returns:
        Array shaped like ``landmark_deltas`` holding
        ``delta_x * width + center_x`` and ``delta_y * height + center_y``.
        For an empty ``boxes`` a ``(0, landmark_deltas.shape[1])`` array of
        zeros is returned.
    """
    if boxes.shape[0] == 0:
        return np.zeros((0, landmark_deltas.shape[1]))
    # ``np.float`` no longer exists in current NumPy; float64 is equivalent.
    boxes = boxes.astype(np.float64, copy=False)
    widths = boxes[:, 2] - boxes[:, 0] + 1.0
    heights = boxes[:, 3] - boxes[:, 1] + 1.0
    ctr_x = boxes[:, 0] + 0.5 * (widths - 1.0)
    ctr_y = boxes[:, 1] + 0.5 * (heights - 1.0)
    pred = landmark_deltas.copy()
    # Broadcast over the landmark axis so any number of points is decoded.
    pred[:, :, 0] = (landmark_deltas[:, :, 0] * widths[:, np.newaxis] +
                     ctr_x[:, np.newaxis])
    pred[:, :, 1] = (landmark_deltas[:, :, 1] * heights[:, np.newaxis] +
                     ctr_y[:, np.newaxis])
    return pred
class FaceDetector:
    """Anchor-based face detector backed by an MXNet checkpoint.

    ``prepare`` must be called before ``detect``: it loads the checkpoint,
    binds the module and builds the per-stride anchor tables.
    """

    def __init__(self, param_file, rac):
        """Store the checkpoint path and the anchor configuration name.

        Args:
            param_file: Checkpoint file; ``prepare`` reads the prefix before
                the last ``-`` and the epoch between that ``-`` and the last
                ``.`` of this path.
            rac: Anchor configuration: ``'net3'``, ``'net3l'`` or ``'net5'``.
        """
        self.param_file = param_file
        self.rac = rac
        self.default_image_size = (480, 640)

    def _init_anchor_cfg(self):
        """Set ``landmark_std``, ``_feat_stride_fpn`` and ``anchor_cfg`` from ``self.rac``."""
        self.landmark_std = 1.0
        _ratio = (1., )
        fmc = 3
        if self.rac == 'net3':
            pass
        elif self.rac == 'net3l':
            self.landmark_std = 0.2
        elif self.rac == 'net5':  # retinaface
            fmc = 5
        else:
            raise AssertionError('rac setting error %s' % self.rac)
        if fmc == 3:
            self._feat_stride_fpn = [32, 16, 8]
            self.anchor_cfg = {
                '32': {
                    'SCALES': (32, 16),
                    'BASE_SIZE': 16,
                    'RATIOS': _ratio,
                    'ALLOWED_BORDER': 9999
                },
                '16': {
                    'SCALES': (8, 4),
                    'BASE_SIZE': 16,
                    'RATIOS': _ratio,
                    'ALLOWED_BORDER': 9999
                },
                '8': {
                    'SCALES': (2, 1),
                    'BASE_SIZE': 16,
                    'RATIOS': _ratio,
                    'ALLOWED_BORDER': 9999
                },
            }
        else:
            self._feat_stride_fpn = [64, 32, 16, 8, 4]
            self.anchor_cfg = {}
            _ass = 2.0**(1.0 / 3)
            # The running scale is deliberately carried over from one stride
            # to the next, so every level continues where the previous ended.
            _basescale = 1.0
            for _stride in [4, 8, 16, 32, 64]:
                scales = []
                for _ in range(3):
                    scales.append(_basescale)
                    _basescale *= _ass
                self.anchor_cfg[str(_stride)] = {
                    'BASE_SIZE': 16,
                    'RATIOS': _ratio,
                    'ALLOWED_BORDER': 9999,
                    'SCALES': tuple(scales),
                }

    def prepare(self, ctx_id, nms=0.4, fix_image_size=None):
        """Load the checkpoint, bind the module and build the anchor tables.

        Args:
            ctx_id: GPU index; a negative value selects the CPU.
            nms: Overlap threshold later used by ``nms``.
            fix_image_size: Optional ``(height, width)`` tuple used for
                binding instead of ``self.default_image_size``.
        """
        pos = self.param_file.rfind('-')
        prefix = self.param_file[0:pos]
        pos2 = self.param_file.rfind('.')
        epoch = int(self.param_file[pos + 1:pos2])
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
        if ctx_id >= 0:
            ctx = mx.gpu(ctx_id)
        else:
            ctx = mx.cpu()
        model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
        if fix_image_size is not None:
            data_shape = (1, 3) + fix_image_size
        else:
            data_shape = (1, 3) + self.default_image_size
        model.bind(data_shapes=[('data', data_shape)])
        model.set_params(arg_params, aux_params)
        #warmup
        data = mx.nd.zeros(shape=data_shape)
        db = mx.io.DataBatch(data=(data, ))
        model.forward(db, is_train=False)
        model.get_outputs()[0].asnumpy()
        self.model = model
        self.nms_threshold = nms
        self._init_anchor_cfg()
        print(self._feat_stride_fpn, self.anchor_cfg)
        # Three outputs per stride means a landmark branch is present.
        self.use_landmarks = len(sym) // len(self._feat_stride_fpn) == 3
        print('use_landmarks', self.use_landmarks)
        self.fpn_keys = ['stride%s' % s for s in self._feat_stride_fpn]
        self._anchors_fpn = dict(
            zip(self.fpn_keys, generate_anchors_fpn(cfg=self.anchor_cfg)))
        for k in self._anchors_fpn:
            self._anchors_fpn[k] = self._anchors_fpn[k].astype(np.float32)
        self.anchor_plane_cache = {}
        self._num_anchors = dict(
            zip(self.fpn_keys,
                [anchors.shape[0] for anchors in self._anchors_fpn.values()]))

    def detect(self, img, threshold=0.5, scale=1.0):
        """Detect faces in ``img``.

        Args:
            img: Image array indexed ``[row, col, channel]`` with 3 channels;
                the channel order is reversed before it is fed to the model.
            threshold: Minimum score for a proposal to be kept.
            scale: Resize factor applied before inference; results are
                divided by it again.

        Returns:
            ``(det, landmarks)``. ``det`` has one row per kept detection:
            the four box coordinates, the score, then any extra proposal
            columns. ``landmarks`` holds the matching landmark arrays, or
            ``None`` when the model has no landmark outputs.
        """
        proposals_list = []
        scores_list = []
        landmarks_list = []
        if scale == 1.0:
            im = img
        else:
            im = cv2.resize(img,
                            None,
                            None,
                            fx=scale,
                            fy=scale,
                            interpolation=cv2.INTER_LINEAR)
        im_tensor = np.zeros((1, 3, im.shape[0], im.shape[1]))
        for i in range(3):
            im_tensor[0, i, :, :] = im[:, :, 2 - i]
        data = nd.array(im_tensor)
        db = mx.io.DataBatch(data=(data, ),
                             provide_data=[('data', data.shape)])
        self.model.forward(db, is_train=False)
        net_out = self.model.get_outputs()
        for _idx, s in enumerate(self._feat_stride_fpn):
            stride = int(s)
            if self.use_landmarks:
                idx = _idx * 3
            else:
                idx = _idx * 2
            scores = net_out[idx].asnumpy()
            # Keep only the channels from index A onward.
            scores = scores[:, self._num_anchors['stride%s' % s]:, :, :]
            idx += 1
            bbox_deltas = net_out[idx].asnumpy()
            height, width = bbox_deltas.shape[2], bbox_deltas.shape[3]
            A = self._num_anchors['stride%s' % s]
            K = height * width
            key = (height, width, stride)
            if key in self.anchor_plane_cache:
                anchors = self.anchor_plane_cache[key]
            else:
                anchors_fpn = self._anchors_fpn['stride%s' % s]
                anchors = anchors_plane(height, width, stride, anchors_fpn)
                anchors = anchors.reshape((K * A, 4))
                # The cache stops growing after 100 distinct plane sizes.
                if len(self.anchor_plane_cache) < 100:
                    self.anchor_plane_cache[key] = anchors
            scores = clip_pad(scores, (height, width))
            scores = scores.transpose((0, 2, 3, 1)).reshape((-1, 1))
            bbox_deltas = clip_pad(bbox_deltas, (height, width))
            bbox_deltas = bbox_deltas.transpose((0, 2, 3, 1))
            bbox_pred_len = bbox_deltas.shape[3] // A
            bbox_deltas = bbox_deltas.reshape((-1, bbox_pred_len))
            proposals = bbox_pred(anchors, bbox_deltas)
            #proposals = clip_boxes(proposals, im_info[:2])
            scores_ravel = scores.ravel()
            order = np.where(scores_ravel >= threshold)[0]
            proposals = proposals[order, :]
            scores = scores[order]
            proposals[:, 0:4] /= scale
            proposals_list.append(proposals)
            scores_list.append(scores)
            if self.use_landmarks:
                idx += 1
                landmark_deltas = net_out[idx].asnumpy()
                landmark_deltas = clip_pad(landmark_deltas, (height, width))
                landmark_pred_len = landmark_deltas.shape[1] // A
                landmark_deltas = landmark_deltas.transpose(
                    (0, 2, 3, 1)).reshape((-1, 5, landmark_pred_len // 5))
                landmark_deltas *= self.landmark_std
                landmarks = landmark_pred(anchors, landmark_deltas)
                landmarks = landmarks[order, :]
                landmarks[:, :, 0:2] /= scale
                landmarks_list.append(landmarks)
        proposals = np.vstack(proposals_list)
        landmarks = None
        if proposals.shape[0] == 0:
            if self.use_landmarks:
                landmarks = np.zeros((0, 5, 2))
            return np.zeros((0, 5)), landmarks
        scores = np.vstack(scores_list)
        scores_ravel = scores.ravel()
        order = scores_ravel.argsort()[::-1]
        proposals = proposals[order, :]
        scores = scores[order]
        if self.use_landmarks:
            landmarks = np.vstack(landmarks_list)
            landmarks = landmarks[order].astype(np.float32, copy=False)
        pre_det = np.hstack((proposals[:, 0:4], scores)).astype(np.float32,
                                                                copy=False)
        keep = self.nms(pre_det)
        det = np.hstack((pre_det, proposals[:, 4:]))
        det = det[keep, :]
        if self.use_landmarks:
            landmarks = landmarks[keep]
        return det, landmarks

    def nms(self, dets):
        """Greedy non-maximum suppression.

        Args:
            dets: Array with columns ``x1, y1, x2, y2, score``.

        Returns:
            List of row indices to keep, highest score first. A box is
            dropped when its overlap ratio with an already kept box exceeds
            ``self.nms_threshold``.
        """
        thresh = self.nms_threshold
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = dets[:, 2]
        y2 = dets[:, 3]
        scores = dets[:, 4]
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]
        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            rest = order[1:]
            xx1 = np.maximum(x1[i], x1[rest])
            yy1 = np.maximum(y1[i], y1[rest])
            xx2 = np.minimum(x2[i], x2[rest])
            yy2 = np.minimum(y2[i], y2[rest])
            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[rest] - inter)
            inds = np.where(ovr <= thresh)[0]
            # ``inds`` indexes ``rest``, which is offset by one from ``order``.
            order = order[inds + 1]
        return keep
def get_retinaface(name, rac='net3', root='~/.insightface/models', **kwargs):
    """Create a ``FaceDetector`` for the model ``retinaface_<name>``.

    The parameter file is resolved through ``model_store.get_model_file``
    under ``root``. Extra keyword arguments are accepted but not used.
    """
    from .model_store import get_model_file
    model_name = "retinaface_%s" % name
    param_path = get_model_file(model_name, root=root)
    return FaceDetector(param_path, rac)
def retinaface_r50_v1(**kwargs):
    """Return the ``r50_v1`` RetinaFace detector using the ``net3`` anchors."""
    detector = get_retinaface("r50_v1", rac='net3', **kwargs)
    return detector
def retinaface_mnet025_v1(**kwargs):
    """Return the ``mnet025_v1`` RetinaFace detector using the ``net3`` anchors."""
    detector = get_retinaface("mnet025_v1", rac='net3', **kwargs)
    return detector
def retinaface_mnet025_v2(**kwargs):
    """Return the ``mnet025_v2`` RetinaFace detector using the ``net3l`` anchors."""
    detector = get_retinaface("mnet025_v2", rac='net3l', **kwargs)
    return detector

View File

@@ -1,73 +0,0 @@
from __future__ import division
import mxnet as mx
import numpy as np
import cv2
__all__ = ['FaceGenderage', 'genderage_v1', 'get_genderage']
class FaceGenderage:
    """Gender/age estimator backed by an MXNet checkpoint."""

    def __init__(self, name, download, param_file):
        """Record the model description.

        Args:
            name: Model name.
            download: When true, ``param_file`` must be given.
            param_file: Checkpoint file path, or ``None``.
        """
        self.name = name
        self.download = download
        self.param_file = param_file
        self.image_size = (112, 112)
        if download:
            assert param_file

    def prepare(self, ctx_id):
        """Load and bind the checkpoint; does nothing without ``param_file``.

        Args:
            ctx_id: GPU index; a negative value selects the CPU.
        """
        if not self.param_file:
            return
        dash = self.param_file.rfind('-')
        prefix = self.param_file[0:dash]
        dot = self.param_file.rfind('.')
        epoch = int(self.param_file[dash + 1:dot])
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
        sym = sym.get_internals()['fc1_output']
        ctx = mx.gpu(ctx_id) if ctx_id >= 0 else mx.cpu()
        model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
        data_shape = (1, 3) + self.image_size
        model.bind(data_shapes=[('data', data_shape)])
        model.set_params(arg_params, aux_params)
        # warmup: one forward pass on zeros, result discarded
        warmup = mx.io.DataBatch(data=(mx.nd.zeros(shape=data_shape), ))
        model.forward(warmup, is_train=False)
        model.get_outputs()[0].asnumpy()
        self.model = model

    def get(self, img):
        """Estimate gender and age for one face crop.

        Args:
            img: 3-channel image whose first two dimensions equal
                ``self.image_size``.

        Returns:
            ``(gender, age)``: ``gender`` is the argmax over output columns
            0-1; ``age`` is the sum of the per-row argmax over columns 2-201
            viewed as 100 rows of 2.
        """
        assert self.param_file and self.model
        assert img.shape[2] == 3 and img.shape[0:2] == self.image_size
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        batch = np.expand_dims(np.transpose(rgb, (2, 0, 1)), axis=0)
        db = mx.io.DataBatch(data=(mx.nd.array(batch), ))
        self.model.forward(db, is_train=False)
        ret = self.model.get_outputs()[0].asnumpy()
        gender = np.argmax(ret[:, 0:2].flatten())
        age_votes = np.argmax(ret[:, 2:202].reshape((100, 2)), axis=1)
        age = int(sum(age_votes))
        return gender, age
def get_genderage(name, download=True, root='~/.insightface/models', **kwargs):
    """Create a ``FaceGenderage`` model.

    With ``download`` false the model is created without a parameter file;
    otherwise the file of ``genderage_<name>`` is resolved through
    ``model_store.get_model_file`` under ``root``. Extra keyword arguments
    are accepted but not used.
    """
    if download:
        from .model_store import get_model_file
        param_path = get_model_file("genderage_%s" % name, root=root)
        return FaceGenderage(name, True, param_path)
    return FaceGenderage(name, False, None)
def genderage_v1(**kwargs):
    """Return the ``v1`` gender/age model, resolving its parameter file."""
    model = get_genderage("v1", download=True, **kwargs)
    return model

View File

@@ -1,86 +0,0 @@
from __future__ import division
import mxnet as mx
import numpy as np
import cv2
__all__ = [
'FaceRecognition', 'arcface_r100_v1', 'arcface_outofreach_v1',
'arcface_mfn_v1', 'get_arcface'
]
class FaceRecognition:
    """Face embedding model backed by an MXNet checkpoint."""

    def __init__(self, name, download, param_file):
        """Record the model description.

        Args:
            name: Model name.
            download: When true, ``param_file`` must be given.
            param_file: Checkpoint file path, or ``None``.
        """
        self.name = name
        self.download = download
        self.param_file = param_file
        self.image_size = (112, 112)
        if download:
            assert param_file

    def prepare(self, ctx_id):
        """Load and bind the checkpoint; does nothing without ``param_file``.

        Args:
            ctx_id: GPU index; a negative value selects the CPU.
        """
        if not self.param_file:
            return
        dash = self.param_file.rfind('-')
        prefix = self.param_file[0:dash]
        dot = self.param_file.rfind('.')
        epoch = int(self.param_file[dash + 1:dot])
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
        sym = sym.get_internals()['fc1_output']
        ctx = mx.gpu(ctx_id) if ctx_id >= 0 else mx.cpu()
        model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
        data_shape = (1, 3) + self.image_size
        model.bind(data_shapes=[('data', data_shape)])
        model.set_params(arg_params, aux_params)
        # warmup: one forward pass on zeros, result discarded
        warmup = mx.io.DataBatch(data=(mx.nd.zeros(shape=data_shape), ))
        model.forward(warmup, is_train=False)
        model.get_outputs()[0].asnumpy()
        self.model = model

    def get_embedding(self, img):
        """Return the model output for one face crop.

        Args:
            img: 3-channel image whose first two dimensions equal
                ``self.image_size``.
        """
        assert self.param_file and self.model
        assert img.shape[2] == 3 and img.shape[0:2] == self.image_size
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        batch = np.expand_dims(np.transpose(rgb, (2, 0, 1)), axis=0)
        db = mx.io.DataBatch(data=(mx.nd.array(batch), ))
        self.model.forward(db, is_train=False)
        return self.model.get_outputs()[0].asnumpy()

    def compute_sim(self, img1, img2):
        """Return the cosine similarity of the embeddings of two images."""
        from numpy.linalg import norm
        first = self.get_embedding(img1).flatten()
        second = self.get_embedding(img2).flatten()
        return np.dot(first, second) / (norm(first) * norm(second))
def get_arcface(name, download=True, root='~/.insightface/models', **kwargs):
    """Create a ``FaceRecognition`` model.

    With ``download`` false the model is created without a parameter file;
    otherwise the file of ``arcface_<name>`` is resolved through
    ``model_store.get_model_file`` under ``root``. Extra keyword arguments
    are accepted but not used.
    """
    if download:
        from .model_store import get_model_file
        param_path = get_model_file("arcface_%s" % name, root=root)
        return FaceRecognition(name, True, param_path)
    return FaceRecognition(name, False, None)
def arcface_r100_v1(**kwargs):
    """Return the ``r100_v1`` ArcFace model, resolving its parameter file."""
    model = get_arcface("r100_v1", download=True, **kwargs)
    return model
def arcface_mfn_v1(**kwargs):
    """Return the ``mfn_v1`` ArcFace model, resolving its parameter file."""
    model = get_arcface("mfn_v1", download=True, **kwargs)
    return model
def arcface_outofreach_v1(**kwargs):
    """Return the ``outofreach_v1`` ArcFace model without a parameter file."""
    model = get_arcface("outofreach_v1", download=False, **kwargs)
    return model

View File

@@ -22,7 +22,7 @@ _model_sha1 = {
]
}
base_repo_url = 'http://insightface.ai/files/'
base_repo_url = 'https://insightface.ai/files/'
_url_format = '{repo_url}models/{file_name}.zip'

View File

@@ -1,56 +1,59 @@
# pylint: disable=wildcard-import, unused-wildcard-import
"""
This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_zoo.py
"""
from .face_recognition import *
from .face_detection import *
from .face_genderage import *
#from .face_alignment import *
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
__all__ = ['get_model', 'get_model_list']
import os
import os.path as osp
import glob
import onnxruntime
from .arcface_onnx import *
from .scrfd import *
_models = {
'arcface_r100_v1': arcface_r100_v1,
#'arcface_mfn_v1': arcface_mfn_v1,
#'arcface_outofreach_v1': arcface_outofreach_v1,
'retinaface_r50_v1': retinaface_r50_v1,
'retinaface_mnet025_v1': retinaface_mnet025_v1,
'retinaface_mnet025_v2': retinaface_mnet025_v2,
'genderage_v1': genderage_v1,
}
#__all__ = ['get_model', 'get_model_list', 'get_arcface_onnx', 'get_scrfd']
__all__ = ['get_model']
class ModelRouter:
    """Pick the wrapper class that matches an ONNX model file."""

    def __init__(self, onnx_file):
        """Remember the path of the ONNX file to route."""
        self.onnx_file = onnx_file

    def get_model(self):
        """Open the file with ONNX Runtime and wrap the session.

        Returns:
            ``SCRFD`` when the model has at least 5 outputs, otherwise
            ``ArcFaceONNX`` when input dimensions 2 and 3 are both 112.

        Raises:
            RuntimeError: When neither rule matches.
        """
        session = onnxruntime.InferenceSession(self.onnx_file, None)
        input_shape = session.get_inputs()[0].shape
        num_outputs = len(session.get_outputs())
        if num_outputs >= 5:
            return SCRFD(model_file=self.onnx_file, session=session)
        if input_shape[2] == 112 and input_shape[3] == 112:
            return ArcFaceONNX(model_file=self.onnx_file, session=session)
        raise RuntimeError('error on model routing')
def find_onnx_file(dir_path):
    """Return the last ``*.onnx`` path in ``dir_path`` in sorted order.

    Returns ``None`` when the directory does not exist or holds no
    ``.onnx`` file.
    """
    if not os.path.exists(dir_path):
        return None
    candidates = sorted(glob.glob("%s/*.onnx" % dir_path))
    return candidates[-1] if candidates else None
def get_model(name, **kwargs):
"""Returns a pre-defined model by name
root = kwargs.get('root', '~/.insightface/models')
root = os.path.expanduser(root)
if not name.endswith('.onnx'):
model_dir = os.path.join(root, name)
model_file = find_onnx_file(model_dir)
if model_file is None:
return None
else:
model_file = name
assert osp.isfile(model_file), 'model should be file'
router = ModelRouter(name)
model = router.get_model()
#print('get-model for ', name,' : ', model.taskname)
return model
Parameters
----------
name : str
Name of the model.
root : str, default '~/.insightface/models'
Location for keeping the model parameters.
Returns
-------
Model
The model.
"""
name = name.lower()
if name not in _models:
err_str = '"%s" is not among the following model list:\n\t' % (name)
err_str += '%s' % ('\n\t'.join(sorted(_models.keys())))
raise ValueError(err_str)
net = _models[name](**kwargs)
return net
def get_model_list():
    """Return the names of all registered models in sorted order.

    Returns
    -------
    list of str
        Sorted keys of the module-level ``_models`` registry.
    """
    names = sorted(_models)
    return names

View File

@@ -0,0 +1,353 @@
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
from __future__ import division
import datetime
import numpy as np
import onnx
import onnxruntime
import os
import os.path as osp
import cv2
import sys
def softmax(z):
    """Row-wise softmax of a 2-D array.

    The row maximum is subtracted before exponentiation; every output row
    sums to one.
    """
    assert len(z.shape) == 2
    row_max = np.max(z, axis=1)[:, np.newaxis]  # column vector for broadcasting
    exps = np.exp(z - row_max)
    row_sum = np.sum(exps, axis=1)[:, np.newaxis]  # column vector for broadcasting
    return exps / row_sum
def distance2bbox(points, distance, max_shape=None):
    """Decode distance prediction to bounding box.

    Args:
        points (ndarray): Shape (n, 2), [x, y].
        distance (ndarray): Distance from the given point to 4
            boundaries (left, top, right, bottom).
        max_shape (tuple): Optional (height, width) of the image; when
            given, x is clipped to [0, width] and y to [0, height].

    Returns:
        ndarray: Decoded bboxes, shape (n, 4) as [x1, y1, x2, y2].
    """
    x1 = points[:, 0] - distance[:, 0]
    y1 = points[:, 1] - distance[:, 1]
    x2 = points[:, 0] + distance[:, 2]
    y2 = points[:, 1] + distance[:, 3]
    if max_shape is not None:
        # ndarrays have no ``clamp`` method; ``np.clip`` is the equivalent.
        x1 = np.clip(x1, 0, max_shape[1])
        y1 = np.clip(y1, 0, max_shape[0])
        x2 = np.clip(x2, 0, max_shape[1])
        y2 = np.clip(y2, 0, max_shape[0])
    return np.stack([x1, y1, x2, y2], axis=-1)
def distance2kps(points, distance, max_shape=None):
    """Decode keypoint offsets to absolute keypoint coordinates.

    Args:
        points (ndarray): Shape (n, 2), [x, y].
        distance (ndarray): Shape (n, 2 * k); consecutive column pairs are
            the (x, y) offsets of one keypoint relative to ``points``.
        max_shape (tuple): Optional (height, width) of the image; when
            given, x is clipped to [0, width] and y to [0, height].

    Returns:
        ndarray: Shape (n, 2 * k), interleaved as [x0, y0, x1, y1, ...].
    """
    preds = []
    for i in range(0, distance.shape[1], 2):
        px = points[:, 0] + distance[:, i]
        py = points[:, 1] + distance[:, i + 1]
        if max_shape is not None:
            # ndarrays have no ``clamp`` method; ``np.clip`` is the equivalent.
            px = np.clip(px, 0, max_shape[1])
            py = np.clip(py, 0, max_shape[0])
        preds.append(px)
        preds.append(py)
    return np.stack(preds, axis=-1)
class SCRFD:
def __init__(self, model_file=None, session=None):
    """Set up the detector around an ONNX Runtime session.

    Args:
        model_file: Path of the ONNX model; needed when ``session`` is None.
        session: Optional, already created inference session.
    """
    import onnxruntime
    self.model_file = model_file
    self.taskname = 'detection'
    self.center_cache = {}
    self.nms_threshold = 0.4
    self.session = session
    if self.session is None:
        # No session supplied: the model file must exist so one can be built.
        assert self.model_file is not None
        assert osp.exists(self.model_file)
        self.session = onnxruntime.InferenceSession(self.model_file, None)
    self._init_vars()
def _init_vars(self):
    """Derive input/output names and the head layout from the session.

    ``input_size`` is input dimensions 2 and 3 reversed, or ``None`` when
    dimension 2 is symbolic (a string). The number of outputs selects the
    pyramid depth, the strides, the anchors per location and whether
    keypoints are produced; an unlisted count leaves ``fmc`` and
    ``_feat_stride_fpn`` unset.
    """
    input_cfg = self.session.get_inputs()[0]
    input_shape = input_cfg.shape
    if isinstance(input_shape[2], str):
        self.input_size = None
    else:
        self.input_size = tuple(input_shape[2:4][::-1])
    input_name = input_cfg.name
    outputs = self.session.get_outputs()
    self.input_name = input_name
    self.output_names = [o.name for o in outputs]
    self.use_kps = False
    self._anchor_ratio = 1.0
    self._num_anchors = 1
    # number of outputs -> (fmc, strides, anchors per location, use_kps)
    layouts = {
        6: (3, [8, 16, 32], 2, False),
        9: (3, [8, 16, 32], 2, True),
        10: (5, [8, 16, 32, 64, 128], 1, False),
        15: (5, [8, 16, 32, 64, 128], 1, True),
    }
    layout = layouts.get(len(outputs))
    if layout is not None:
        self.fmc, self._feat_stride_fpn, self._num_anchors, self.use_kps = layout
def prepare(self, ctx_id, **kwargs):
    """Configure the execution provider, NMS threshold and input size.

    Args:
        ctx_id: A negative value restricts the session to the CPU provider.
        **kwargs: ``nms_threshold`` replaces ``self.nms_threshold`` when not
            None. ``input_size`` is stored only when the model itself did
            not fix one; otherwise a warning is printed and it is ignored.
    """
    if ctx_id < 0:
        self.session.set_providers(['CPUExecutionProvider'])
    nms_threshold = kwargs.get('nms_threshold')
    if nms_threshold is not None:
        self.nms_threshold = nms_threshold
    input_size = kwargs.get('input_size')
    if input_size is None:
        return
    if self.input_size is None:
        # NOTE: an earlier, disabled approach re-exported a fixed-shape
        # "_selfgen_shape" ONNX file with onnxsim here and reloaded the
        # session; now the requested size is simply recorded.
        self.input_size = input_size
    else:
        print('warning: det_size is already set in scrfd model, ignore')
def forward(self, img, threshold):
    """Run the network on ``img`` and decode the detections of every stride.

    The image is fed at its own resolution (no resizing happens here); it is
    converted to a blob with mean 127.5, scale 1/128 and ``swapRB=True``.

    Args:
        img: image array whose first two axes are (height, width).
        threshold: minimum score for a candidate to be kept.

    Returns:
        ``(scores_list, bboxes_list, kpss_list)`` with one entry per value
        in ``self._feat_stride_fpn``. ``kpss_list`` stays empty when
        ``self.use_kps`` is false.
    """
    scores_list, bboxes_list, kpss_list = [], [], []

    blob_size = tuple(img.shape[0:2][::-1])
    blob = cv2.dnn.blobFromImage(
        img, 1.0/128, blob_size, (127.5, 127.5, 127.5), swapRB=True)
    net_outs = self.session.run(self.output_names, {self.input_name: blob})
    blob_h, blob_w = blob.shape[2], blob.shape[3]

    # Outputs are grouped by kind: `fmc` score maps first, then `fmc` box
    # maps, then (optionally) `fmc` keypoint maps.
    group = self.fmc
    for level, stride in enumerate(self._feat_stride_fpn):
        level_scores = net_outs[level]
        box_deltas = net_outs[level + group] * stride
        if self.use_kps:
            kps_deltas = net_outs[level + group*2] * stride

        grid_h = blob_h // stride
        grid_w = blob_w // stride
        cache_key = (grid_h, grid_w, stride)
        if cache_key not in self.center_cache:
            # (x, y) pixel position of every feature-map cell, row-major.
            ys, xs = np.mgrid[:grid_h, :grid_w]
            centers = np.stack([xs, ys], axis=-1).astype(np.float32)
            centers = (centers * stride).reshape((-1, 2))
            if self._num_anchors > 1:
                # Each cell carries several anchors sharing one centre.
                centers = np.stack(
                    [centers]*self._num_anchors, axis=1).reshape((-1, 2))
            # The cache is capped so that arbitrary input sizes cannot grow
            # it without bound.
            if len(self.center_cache) < 100:
                self.center_cache[cache_key] = centers
        else:
            centers = self.center_cache[cache_key]

        selected = np.where(level_scores >= threshold)[0]
        decoded_boxes = distance2bbox(centers, box_deltas)
        scores_list.append(level_scores[selected])
        bboxes_list.append(decoded_boxes[selected])
        if self.use_kps:
            decoded_kps = distance2kps(centers, kps_deltas)
            decoded_kps = decoded_kps.reshape((decoded_kps.shape[0], -1, 2))
            kpss_list.append(decoded_kps[selected])
    return scores_list, bboxes_list, kpss_list
def detect(self, img, threshold=0.5, input_size = None, max_num=0, metric='default'):
    """Detect faces in ``img`` and return boxes (and keypoints) in image coordinates.

    The image is resized with its aspect ratio preserved so that it fits
    inside ``input_size``, pasted into the top-left corner of a zero-filled
    canvas of exactly that size, run through :meth:`forward`, and the
    results are scaled back, sorted by score and filtered with :meth:`nms`.

    Args:
        img: image array indexed as (height, width, channel).
        threshold: minimum detection score passed on to :meth:`forward`.
        input_size: ``(width, height)`` of the network input; falls back to
            ``self.input_size`` when None. One of the two must be set.
        max_num: when > 0, keep at most this many detections.
        metric: ranking used when ``max_num`` truncates the result.
            ``'max'`` ranks by box area only; any other value ranks by
            area minus twice the squared distance of the box centre from
            the image centre.

    Returns:
        ``(det, kpss)``. ``det`` is a float32 array whose rows are the box
        coordinates followed by the score. ``kpss`` holds the matching
        keypoints per row of ``det``, or is None when ``self.use_kps`` is
        false.
    """
    assert input_size is not None or self.input_size is not None, \
        'input_size must be given either here or via prepare()'
    input_size = self.input_size if input_size is None else input_size

    # Fit the image inside the model input while keeping its aspect ratio.
    im_ratio = float(img.shape[0]) / img.shape[1]
    model_ratio = float(input_size[1]) / input_size[0]
    if im_ratio > model_ratio:
        new_height = input_size[1]
        new_width = int(new_height / im_ratio)
    else:
        new_width = input_size[0]
        new_height = int(new_width * im_ratio)
    det_scale = float(new_height) / img.shape[0]
    resized_img = cv2.resize(img, (new_width, new_height))
    det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8)
    det_img[:new_height, :new_width, :] = resized_img

    scores_list, bboxes_list, kpss_list = self.forward(det_img, threshold)

    scores = np.vstack(scores_list)
    order = scores.ravel().argsort()[::-1]
    # Undo the resize so coordinates refer to the original image.
    bboxes = np.vstack(bboxes_list) / det_scale
    pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
    pre_det = pre_det[order, :]
    keep = self.nms(pre_det)
    det = pre_det[keep, :]
    if self.use_kps:
        kpss = np.vstack(kpss_list) / det_scale
        # Apply the same ordering and NMS selection as for the boxes.
        kpss = kpss[order, :, :]
        kpss = kpss[keep, :, :]
    else:
        kpss = None

    if max_num > 0 and det.shape[0] > max_num:
        area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
        img_center = img.shape[0] // 2, img.shape[1] // 2
        offsets = np.vstack([
            (det[:, 0] + det[:, 2]) / 2 - img_center[1],
            (det[:, 1] + det[:, 3]) / 2 - img_center[0]
        ])
        offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
        if metric == 'max':
            values = area
        else:
            # some extra weight on the centering
            values = area - offset_dist_squared * 2.0
        bindex = np.argsort(values)[::-1]
        bindex = bindex[0:max_num]
        # The selection must be applied to `det` (the returned array), not
        # to the pre-NMS `bboxes`; otherwise `det` is never truncated and
        # no longer lines up with `kpss`.
        det = det[bindex, :]
        if kpss is not None:
            kpss = kpss[bindex, :]
    return det, kpss
def nms(self, dets):
    """Greedy non-maximum suppression over scored boxes.

    Args:
        dets: 2-D array whose columns 0-3 are ``x1, y1, x2, y2`` and whose
            column 4 is the score.

    Returns:
        List of row indices into ``dets`` that survive, highest score
        first. A box is dropped when its IoU with an already kept box
        exceeds ``self.nms_threshold``.
    """
    iou_limit = self.nms_threshold
    left, top, right, bottom, conf = (dets[:, col] for col in range(5))

    # Widths/heights use the inclusive-pixel (+1) convention.
    box_area = (right - left + 1) * (bottom - top + 1)
    remaining = conf.argsort()[::-1]

    selected = []
    while remaining.size > 0:
        best, rest = remaining[0], remaining[1:]
        selected.append(best)

        ix1 = np.maximum(left[best], left[rest])
        iy1 = np.maximum(top[best], top[rest])
        ix2 = np.minimum(right[best], right[rest])
        iy2 = np.minimum(bottom[best], bottom[rest])

        overlap_w = np.maximum(0.0, ix2 - ix1 + 1)
        overlap_h = np.maximum(0.0, iy2 - iy1 + 1)
        overlap = overlap_w * overlap_h
        iou = overlap / (box_area[best] + box_area[rest] - overlap)

        # Only boxes that do not overlap the winner too much stay in play.
        remaining = rest[iou <= iou_limit]
    return selected
def get_scrfd(name, download=False, root='~/.insightface/models', **kwargs):
    """Build an ``SCRFD`` detector from a local file or from the model store.

    Args:
        name: path to an existing model file when ``download`` is false;
            otherwise the suffix of the model-store entry ``scrfd_<name>``.
        download: when true, resolve the model file through
            ``model_store.get_model_file``.
        root: directory handed to ``get_model_file``; only used when
            ``download`` is true.
        **kwargs: accepted but not used by this function.

    Returns:
        An ``SCRFD`` instance constructed from the resolved model file.
    """
    if download:
        from .model_store import get_model_file
        model_path = get_model_file("scrfd_%s" % name, root=root)
        return SCRFD(model_path)

    assert os.path.exists(name)
    return SCRFD(name)
def scrfd_2p5gkps(**kwargs):
    """Return the ``2p5gkps`` SCRFD model, resolved through the model store.

    Keyword arguments are forwarded unchanged to ``get_scrfd``.
    """
    model_name = "2p5gkps"
    return get_scrfd(model_name, download=True, **kwargs)
if __name__ == '__main__':
    # Manual smoke test: run the detector on a sample image, print timing
    # and shapes, and write an annotated copy under ./outputs.
    import os

    detector = SCRFD(model_file='./det.onnx')
    detector.prepare(-1)
    img_paths = ['tests/data/t1.jpg']

    # cv2.imwrite does not create missing directories (it just reports
    # failure through its return value), so make sure the target exists.
    out_dir = './outputs'
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    for img_path in img_paths:
        img = cv2.imread(img_path)

        for _ in range(1):
            ta = datetime.datetime.now()
            #bboxes, kpss = detector.detect(img, 0.5, input_size = (640, 640))
            bboxes, kpss = detector.detect(img, 0.5)
            tb = datetime.datetime.now()
            print('all cost:', (tb-ta).total_seconds()*1000)
        print(img_path, bboxes.shape)
        if kpss is not None:
            print(kpss.shape)

        for i, bbox in enumerate(bboxes):
            # `np.int` no longer exists in current NumPy; use the builtin
            # and hand OpenCV plain Python ints.
            x1, y1, x2, y2 = bbox[:4].astype(int).tolist()
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
            if kpss is not None:
                for kp in kpss[i]:
                    cv2.circle(img, tuple(kp.astype(int).tolist()), 1, (0, 0, 255), 2)

        filename = os.path.basename(img_path)
        print('output:', filename)
        cv2.imwrite('%s/%s' % (out_dir, filename), img)