mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-17 14:26:08 +00:00
pip package
This commit is contained in:
3
python-package/README.md
Normal file
3
python-package/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
InsightFace.ai README
|
||||
|
||||
|
||||
28
python-package/insightface/__init__.py
Normal file
28
python-package/insightface/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# coding: utf-8
# pylint: disable=wrong-import-position
"""InsightFace: A Face Analysis Toolkit."""
from __future__ import absolute_import

# mxnet version check
#mx_version = '1.4.0'
try:
    # Fail fast with an actionable message when mxnet is missing, since every
    # model in this package is backed by an MXNet checkpoint.
    import mxnet as mx
    #from distutils.version import LooseVersion
    #if LooseVersion(mx.__version__) < LooseVersion(mx_version):
    #    msg = (
    #        "Legacy mxnet-mkl=={} detected, some new modules may not work properly. "
    #        "mxnet-mkl>={} is required. You can use pip to upgrade mxnet "
    #        "`pip install mxnet-mkl --pre --upgrade` "
    #        "or `pip install mxnet-cu90mkl --pre --upgrade`").format(mx.__version__, mx_version)
    #    raise ImportError(msg)
except ImportError:
    raise ImportError(
        "Unable to import dependency mxnet. "
        "A quick tip is to install via `pip install mxnet-mkl/mxnet-cu90mkl --pre`. ")

# Package version; setup.py extracts this string with a regex, so keep the
# `__version__ = '...'` form intact.
__version__ = '0.1.2'

from . import model_zoo
#from . import utils
#from . import analysis
|
||||
|
||||
1
python-package/insightface/model_zoo/__init__.py
Normal file
1
python-package/insightface/model_zoo/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .model_zoo import get_model, get_model_list
|
||||
422
python-package/insightface/model_zoo/face_detection.py
Normal file
422
python-package/insightface/model_zoo/face_detection.py
Normal file
@@ -0,0 +1,422 @@
|
||||
from __future__ import division
|
||||
import mxnet as mx
|
||||
import numpy as np
|
||||
import mxnet.ndarray as nd
|
||||
|
||||
__all__ = ['FaceDetector',
|
||||
'retinaface_r50_v1',
|
||||
'retinaface_mnet025_v1',
|
||||
'get_retinaface']
|
||||
|
||||
def _whctrs(anchor):
|
||||
"""
|
||||
Return width, height, x center, and y center for an anchor (window).
|
||||
"""
|
||||
|
||||
w = anchor[2] - anchor[0] + 1
|
||||
h = anchor[3] - anchor[1] + 1
|
||||
x_ctr = anchor[0] + 0.5 * (w - 1)
|
||||
y_ctr = anchor[1] + 0.5 * (h - 1)
|
||||
return w, h, x_ctr, y_ctr
|
||||
|
||||
|
||||
def _mkanchors(ws, hs, x_ctr, y_ctr):
|
||||
"""
|
||||
Given a vector of widths (ws) and heights (hs) around a center
|
||||
(x_ctr, y_ctr), output a set of anchors (windows).
|
||||
"""
|
||||
|
||||
ws = ws[:, np.newaxis]
|
||||
hs = hs[:, np.newaxis]
|
||||
anchors = np.hstack((x_ctr - 0.5 * (ws - 1),
|
||||
y_ctr - 0.5 * (hs - 1),
|
||||
x_ctr + 0.5 * (ws - 1),
|
||||
y_ctr + 0.5 * (hs - 1)))
|
||||
return anchors
|
||||
|
||||
def _ratio_enum(anchor, ratios):
    """Enumerate, for each aspect ratio, an anchor with (roughly) the same
    area as the reference anchor."""
    w, h, x_ctr, y_ctr = _whctrs(anchor)
    area = w * h
    # Solve ws * hs == area with hs / ws == ratio, then round to whole pixels.
    widths = np.round(np.sqrt(area / ratios))
    heights = np.round(widths * ratios)
    return _mkanchors(widths, heights, x_ctr, y_ctr)
|
||||
|
||||
|
||||
def _scale_enum(anchor, scales):
    """Enumerate, for each scale factor, an anchor that is the reference
    anchor magnified by that factor about its own center."""
    w, h, x_ctr, y_ctr = _whctrs(anchor)
    return _mkanchors(w * scales, h * scales, x_ctr, y_ctr)
|
||||
|
||||
def anchors_plane(height, width, stride, base_anchors):
    """
    Tile a set of base anchors over every cell of a feature-map plane.

    Parameters
    ----------
    height: height of plane
    width: width of plane
    stride: stride of the original image
    base_anchors: (A, 4) a base set of anchors

    Returns
    -------
    all_anchors: (height, width, A, 4) float32 ndarray of anchors spreading
        over the plane; cell (ih, iw) holds base_anchors shifted by
        (iw * stride, ih * stride).
    """
    # Vectorized replacement for the original O(height*width*A) Python triple
    # loop: build a per-cell (dx, dy, dx, dy) shift plane and broadcast-add it
    # to the base anchors.  Output values and dtype are unchanged.
    shift_x = np.arange(width, dtype=np.float32) * stride
    shift_y = np.arange(height, dtype=np.float32) * stride
    shifts = np.zeros((height, width, 1, 4), dtype=np.float32)
    shifts[:, :, 0, 0] = shift_x[np.newaxis, :]
    shifts[:, :, 0, 2] = shift_x[np.newaxis, :]
    shifts[:, :, 0, 1] = shift_y[:, np.newaxis]
    shifts[:, :, 0, 3] = shift_y[:, np.newaxis]
    all_anchors = base_anchors.astype(np.float32)[np.newaxis, np.newaxis, :, :] + shifts
    return all_anchors
|
||||
|
||||
def generate_anchors(base_size=16, ratios=[0.5, 1, 2],
                     scales=2 ** np.arange(3, 6), stride=16):
    """Generate reference anchors by enumerating aspect ratios X scales
    around a (0, 0, base_size-1, base_size-1) window."""
    reference = np.array([1, 1, base_size, base_size]) - 1
    per_ratio = _ratio_enum(reference, ratios)
    # Expand each ratio anchor across all scales, then stack the results.
    expanded = [_scale_enum(per_ratio[row, :], scales)
                for row in range(per_ratio.shape[0])]
    return np.vstack(expanded)
|
||||
|
||||
def generate_anchors_fpn(cfg):
    """Build one base-anchor array per FPN level from a config dict keyed by
    stride (as string), returned in order of decreasing stride."""
    strides = sorted((int(k) for k in cfg), reverse=True)
    anchors = []
    for stride in strides:
        level = cfg[str(stride)]
        base_size = level['BASE_SIZE']
        ratios = np.array(level['RATIOS'])
        scales = np.array(level['SCALES'])
        anchors.append(generate_anchors(base_size, ratios, scales, stride))
    return anchors
|
||||
|
||||
def clip_pad(tensor, pad_shape):
    """
    Crop the spatial dimensions of a padded tensor.

    :param tensor: [n, c, H, W]
    :param pad_shape: [h, w]
    :return: [n, c, h, w] copy when cropping occurs, else the input unchanged
    """
    full_h, full_w = tensor.shape[2:]
    target_h, target_w = pad_shape
    if target_h < full_h or target_w < full_w:
        return tensor[:, :, :target_h, :target_w].copy()
    return tensor
|
||||
|
||||
def bbox_pred(boxes, box_deltas):
    """
    Transform the set of class-agnostic boxes into class-specific boxes
    by applying the predicted offsets (box_deltas).

    :param boxes: !important [N 4]
    :param box_deltas: [N, 4 * num_classes]; columns beyond the first 4 are
        passed through unchanged.
    :return: [N 4 * num_classes]
    """
    if boxes.shape[0] == 0:
        return np.zeros((0, box_deltas.shape[1]))

    # BUG FIX: the `np.float` alias was deprecated in NumPy 1.20 and removed
    # in 1.24; use the explicit dtype instead.
    boxes = boxes.astype(np.float64, copy=False)
    widths = boxes[:, 2] - boxes[:, 0] + 1.0
    heights = boxes[:, 3] - boxes[:, 1] + 1.0
    ctr_x = boxes[:, 0] + 0.5 * (widths - 1.0)
    ctr_y = boxes[:, 1] + 0.5 * (heights - 1.0)

    dx = box_deltas[:, 0:1]
    dy = box_deltas[:, 1:2]
    dw = box_deltas[:, 2:3]
    dh = box_deltas[:, 3:4]

    # Standard R-CNN box decoding: deltas shift the center proportionally to
    # box size and scale width/height exponentially.
    pred_ctr_x = dx * widths[:, np.newaxis] + ctr_x[:, np.newaxis]
    pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis]
    pred_w = np.exp(dw) * widths[:, np.newaxis]
    pred_h = np.exp(dh) * heights[:, np.newaxis]

    pred_boxes = np.zeros(box_deltas.shape)
    # x1
    pred_boxes[:, 0:1] = pred_ctr_x - 0.5 * (pred_w - 1.0)
    # y1
    pred_boxes[:, 1:2] = pred_ctr_y - 0.5 * (pred_h - 1.0)
    # x2
    pred_boxes[:, 2:3] = pred_ctr_x + 0.5 * (pred_w - 1.0)
    # y2
    pred_boxes[:, 3:4] = pred_ctr_y + 0.5 * (pred_h - 1.0)

    if box_deltas.shape[1] > 4:
        pred_boxes[:, 4:] = box_deltas[:, 4:]

    return pred_boxes
|
||||
|
||||
def landmark_pred(boxes, landmark_deltas):
    """Decode per-box landmark deltas into absolute (x, y) coordinates.

    :param boxes: [N, 4] reference boxes
    :param landmark_deltas: [N, 5, 2] offsets, normalized by box width/height
    :return: [N, 5, 2] absolute landmark positions
    """
    if boxes.shape[0] == 0:
        return np.zeros((0, landmark_deltas.shape[1]))
    # BUG FIX: the `np.float` alias was deprecated in NumPy 1.20 and removed
    # in 1.24; use the explicit dtype instead.
    boxes = boxes.astype(np.float64, copy=False)
    widths = boxes[:, 2] - boxes[:, 0] + 1.0
    heights = boxes[:, 3] - boxes[:, 1] + 1.0
    ctr_x = boxes[:, 0] + 0.5 * (widths - 1.0)
    ctr_y = boxes[:, 1] + 0.5 * (heights - 1.0)
    pred = landmark_deltas.copy()
    # Each of the 5 landmarks is offset from the box center, scaled by size.
    for i in range(5):
        pred[:, i, 0] = landmark_deltas[:, i, 0] * widths + ctr_x
        pred[:, i, 1] = landmark_deltas[:, i, 1] * heights + ctr_y
    return pred
|
||||
|
||||
class FaceDetector:
    """RetinaFace-style face detector wrapping a pretrained MXNet checkpoint.

    Typical usage::

        det = FaceDetector('/path/model-0000.params', rac='net3')
        det.prepare(ctx_id=0)
        boxes, landmarks = det.detect(img)
    """

    def __init__(self, param_file, rac):
        # Checkpoint path in the form '<prefix>-<epoch>.params'.
        self.param_file = param_file
        # Anchor-configuration selector: 'net3' (3 FPN strides) or 'net5' (5).
        self.rac = rac
        # (height, width) used to bind the module when fix_image_size is None.
        self.default_image_size = (480, 640)

    def prepare(self, ctx_id, nms=0.4, fix_image_size=None):
        """Load the checkpoint, bind the module and build the anchor tables.

        Parameters
        ----------
        ctx_id : int
            GPU id; any negative value selects CPU.
        nms : float
            IoU threshold used by :meth:`nms`.
        fix_image_size : tuple or None
            Optional fixed (height, width); when given, anchor planes are
            cached per feature-map shape.
        """
        # Split '<prefix>-<epoch>.params' into checkpoint prefix and epoch.
        pos = self.param_file.rfind('-')
        prefix = self.param_file[0:pos]
        pos2 = self.param_file.rfind('.')
        epoch = int(self.param_file[pos + 1:pos2])
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
        if ctx_id >= 0:
            ctx = mx.gpu(ctx_id)
        else:
            ctx = mx.cpu()
        model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
        if fix_image_size is not None:
            data_shape = (1, 3) + fix_image_size
        else:
            data_shape = (1, 3) + self.default_image_size
        model.bind(data_shapes=[('data', data_shape)])
        model.set_params(arg_params, aux_params)
        # Warm up with one dummy forward pass so the first real call does not
        # pay the lazy-initialization cost.
        data = mx.nd.zeros(shape=data_shape)
        db = mx.io.DataBatch(data=(data,))
        model.forward(db, is_train=False)
        out = model.get_outputs()[0].asnumpy()
        self.model = model
        self.nms_threshold = nms

        _ratio = (1.,)
        fmc = 3
        if self.rac == 'net3':
            _ratio = (1.,)
        elif self.rac == 'net5':  # retinaface
            # BUG FIX: the original tested the undefined name `network` here,
            # raising NameError whenever rac was anything but 'net3'.
            fmc = 5
        else:
            assert False, 'rac setting error %s' % self.rac

        if fmc == 3:
            self._feat_stride_fpn = [32, 16, 8]
            self.anchor_cfg = {
                '32': {'SCALES': (32, 16), 'BASE_SIZE': 16, 'RATIOS': _ratio, 'ALLOWED_BORDER': 9999},
                '16': {'SCALES': (8, 4), 'BASE_SIZE': 16, 'RATIOS': _ratio, 'ALLOWED_BORDER': 9999},
                '8': {'SCALES': (2, 1), 'BASE_SIZE': 16, 'RATIOS': _ratio, 'ALLOWED_BORDER': 9999},
            }
        elif fmc == 5:
            self._feat_stride_fpn = [64, 32, 16, 8, 4]
            self.anchor_cfg = {}
            # Scales grow geometrically by 2**(1/3), three per stride.
            _ass = 2.0 ** (1.0 / 3)
            _basescale = 1.0
            for _stride in [4, 8, 16, 32, 64]:
                key = str(_stride)
                value = {'BASE_SIZE': 16, 'RATIOS': _ratio, 'ALLOWED_BORDER': 9999}
                scales = []
                for _ in range(3):
                    scales.append(_basescale)
                    _basescale *= _ass
                value['SCALES'] = tuple(scales)
                self.anchor_cfg[key] = value

        print(self._feat_stride_fpn, self.anchor_cfg)
        # Three outputs per stride (score/bbox/landmark) means the network
        # also predicts landmarks; two means boxes only.
        self.use_landmarks = False
        if len(sym) // len(self._feat_stride_fpn) == 3:
            self.use_landmarks = True
        print('use_landmarks', self.use_landmarks)
        self.fpn_keys = []

        for s in self._feat_stride_fpn:
            self.fpn_keys.append('stride%s' % s)

        self._anchors_fpn = dict(zip(self.fpn_keys, generate_anchors_fpn(cfg=self.anchor_cfg)))
        for k in self._anchors_fpn:
            v = self._anchors_fpn[k].astype(np.float32)
            self._anchors_fpn[k] = v
        # Anchor planes depend only on (height, width, stride); caching them
        # is safe only when the input size is fixed.
        self.anchor_plane_cache = {}
        if fix_image_size is None:
            self.anchor_plane_cache = None

        self._num_anchors = dict(zip(self.fpn_keys, [anchors.shape[0] for anchors in self._anchors_fpn.values()]))

    def detect(self, img, threshold=0.5, scale=1.0):
        """Detect faces in a BGR image.

        Returns ``(det, landmarks)``: ``det`` is an (N, 5+) array of
        [x1, y1, x2, y2, score] rows in original-image coordinates;
        ``landmarks`` is an (N, 5, 2) array when the model predicts
        landmarks, otherwise None.
        """
        proposals_list = []
        scores_list = []
        landmarks_list = []
        if scale == 1.0:
            im = img
        else:
            # BUG FIX: cv2 was referenced without ever being imported in this
            # module, so any scale != 1.0 call failed with NameError.  Import
            # lazily so the cv2 dependency is only needed when resizing.
            import cv2
            im = cv2.resize(img, None, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
        im_info = [im.shape[0], im.shape[1]]
        # BGR HWC image -> RGB CHW float tensor with a batch dimension.
        im_tensor = np.zeros((1, 3, im.shape[0], im.shape[1]))
        for i in range(3):
            im_tensor[0, i, :, :] = im[:, :, 2 - i]
        data = nd.array(im_tensor)
        db = mx.io.DataBatch(data=(data,), provide_data=[('data', data.shape)])
        self.model.forward(db, is_train=False)
        net_out = self.model.get_outputs()
        for _idx, s in enumerate(self._feat_stride_fpn):
            _key = 'stride%s' % s
            stride = int(s)
            # Outputs are grouped per stride: score, bbox[, landmark].
            if self.use_landmarks:
                idx = _idx * 3
            else:
                idx = _idx * 2
            scores = net_out[idx].asnumpy()
            # Keep only the foreground-class score channels.
            scores = scores[:, self._num_anchors['stride%s' % s]:, :, :]
            idx += 1
            bbox_deltas = net_out[idx].asnumpy()

            height, width = bbox_deltas.shape[2], bbox_deltas.shape[3]
            A = self._num_anchors['stride%s' % s]
            K = height * width
            if self.anchor_plane_cache is not None:
                key = (height, width, stride)
                if key in self.anchor_plane_cache:
                    anchors = self.anchor_plane_cache[key]
                else:
                    anchors_fpn = self._anchors_fpn['stride%s' % s]
                    anchors = anchors_plane(height, width, stride, anchors_fpn)
                    anchors = anchors.reshape((K * A, 4))
                    self.anchor_plane_cache[key] = anchors
            else:
                anchors_fpn = self._anchors_fpn['stride%s' % s]
                anchors = anchors_plane(height, width, stride, anchors_fpn)
                anchors = anchors.reshape((K * A, 4))

            scores = clip_pad(scores, (height, width))
            scores = scores.transpose((0, 2, 3, 1)).reshape((-1, 1))

            bbox_deltas = clip_pad(bbox_deltas, (height, width))
            bbox_deltas = bbox_deltas.transpose((0, 2, 3, 1))
            bbox_pred_len = bbox_deltas.shape[3] // A
            bbox_deltas = bbox_deltas.reshape((-1, bbox_pred_len))

            proposals = bbox_pred(anchors, bbox_deltas)
            #proposals = clip_boxes(proposals, im_info[:2])

            scores_ravel = scores.ravel()
            order = np.where(scores_ravel >= threshold)[0]
            proposals = proposals[order, :]
            scores = scores[order]

            # Map boxes back to original-image coordinates.
            proposals[:, 0:4] /= scale

            proposals_list.append(proposals)
            scores_list.append(scores)

            if self.use_landmarks:
                idx += 1
                landmark_deltas = net_out[idx].asnumpy()
                landmark_deltas = clip_pad(landmark_deltas, (height, width))
                landmark_pred_len = landmark_deltas.shape[1] // A
                landmark_deltas = landmark_deltas.transpose((0, 2, 3, 1)).reshape((-1, 5, landmark_pred_len // 5))
                #print(landmark_deltas.shape, landmark_deltas)
                landmarks = landmark_pred(anchors, landmark_deltas)
                landmarks = landmarks[order, :]

                landmarks[:, :, 0:2] /= scale
                landmarks_list.append(landmarks)

        proposals = np.vstack(proposals_list)
        landmarks = None
        if proposals.shape[0] == 0:
            if self.use_landmarks:
                landmarks = np.zeros((0, 5, 2))
            return np.zeros((0, 5)), landmarks
        scores = np.vstack(scores_list)
        scores_ravel = scores.ravel()
        # Sort all proposals by descending score before NMS.
        order = scores_ravel.argsort()[::-1]
        proposals = proposals[order, :]
        scores = scores[order]
        if self.use_landmarks:
            landmarks = np.vstack(landmarks_list)
            landmarks = landmarks[order].astype(np.float32, copy=False)

        pre_det = np.hstack((proposals[:, 0:4], scores)).astype(np.float32, copy=False)
        keep = self.nms(pre_det)
        det = np.hstack((pre_det, proposals[:, 4:]))
        det = det[keep, :]
        if self.use_landmarks:
            landmarks = landmarks[keep]

        return det, landmarks

    def nms(self, dets):
        """Greedy non-maximum suppression.

        ``dets`` rows are [x1, y1, x2, y2, score]; returns the list of kept
        row indices, using ``self.nms_threshold`` as the IoU cutoff.
        """
        thresh = self.nms_threshold
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = dets[:, 2]
        y2 = dets[:, 3]
        scores = dets[:, 4]

        # Inclusive-pixel area convention, matching the anchor code.
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)

            # Keep only boxes that overlap the current best by <= thresh.
            inds = np.where(ovr <= thresh)[0]
            order = order[inds + 1]

        return keep
|
||||
|
||||
|
||||
def get_retinaface(name, rac='net3',
                   root='~/.insightface/models', **kwargs):
    """Construct a FaceDetector whose params are fetched from the model zoo."""
    from .model_store import get_model_file
    params_path = get_model_file("retinaface_%s" % name, root=root)
    detector = FaceDetector(params_path, rac)
    return detector
|
||||
|
||||
def retinaface_r50_v1(**kwargs):
    """RetinaFace detector with a ResNet-50 backbone (pretrained, downloaded on demand)."""
    return get_retinaface("r50_v1", rac='net3', **kwargs)
|
||||
|
||||
def retinaface_mnet025_v1(**kwargs):
    """RetinaFace detector with a MobileNet-0.25 backbone (pretrained, downloaded on demand)."""
    return get_retinaface("mnet025_v1", rac='net3', **kwargs)
|
||||
|
||||
83
python-package/insightface/model_zoo/face_recognition.py
Normal file
83
python-package/insightface/model_zoo/face_recognition.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from __future__ import division
|
||||
import mxnet as mx
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
__all__ = ['FaceRecognition',
|
||||
'arcface_r100_v1', 'arcface_outofreach_v1', 'arcface_mfn_v1',
|
||||
'get_arcface']
|
||||
|
||||
|
||||
class FaceRecognition:
    """ArcFace-style face embedding model backed by a pretrained MXNet checkpoint."""

    def __init__(self, name, download, param_file):
        # Model-zoo name of this model.
        self.name = name
        # Whether the params were obtained from the online model zoo.
        self.download = download
        # Checkpoint path '<prefix>-<epoch>.params', or None when unavailable.
        self.param_file = param_file
        # Expected input size (height, width) for aligned face crops.
        self.image_size = (112, 112)
        if download:
            assert param_file

    def prepare(self, ctx_id):
        """Load the checkpoint and bind the module on GPU `ctx_id` (CPU if negative)."""
        if self.param_file:
            # Split '<prefix>-<epoch>.params' into prefix and epoch number.
            pos = self.param_file.rfind('-')
            prefix = self.param_file[0:pos]
            pos2 = self.param_file.rfind('.')
            epoch = int(self.param_file[pos+1:pos2])
            sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
            # Truncate the symbol graph at the embedding layer's output.
            all_layers = sym.get_internals()
            sym = all_layers['fc1_output']
            if ctx_id>=0:
                ctx = mx.gpu(ctx_id)
            else:
                ctx = mx.cpu()
            model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)
            data_shape = (1,3)+self.image_size
            model.bind(data_shapes=[('data', data_shape)])
            model.set_params(arg_params, aux_params)
            # Warm up with one dummy forward pass so the first real call is
            # not penalized by lazy initialization.
            data = mx.nd.zeros(shape=data_shape)
            db = mx.io.DataBatch(data=(data,))
            model.forward(db, is_train=False)
            embedding = model.get_outputs()[0].asnumpy()
            self.model = model
        else:
            # No params file: nothing to load.
            pass

    def get_embedding(self, img):
        """Return the embedding of an aligned 112x112 BGR image (shape (1, dim))."""
        assert self.param_file and self.model
        assert img.shape[2]==3 and img.shape[0:2]==self.image_size
        # BGR HWC -> RGB CHW with a batch dimension.
        data = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        data = np.transpose(data, (2,0,1))
        data = np.expand_dims(data, axis=0)
        data = mx.nd.array(data)
        db = mx.io.DataBatch(data=(data,))
        self.model.forward(db, is_train=False)
        embedding = self.model.get_outputs()[0].asnumpy()
        return embedding

    def compute_sim(self, img1, img2):
        """Return the cosine similarity between the two images' embeddings."""
        emb1 = self.get_embedding(img1).flatten()
        emb2 = self.get_embedding(img2).flatten()
        from numpy.linalg import norm
        sim = np.dot(emb1, emb2)/(norm(emb1)*norm(emb2))
        return sim
|
||||
|
||||
def get_arcface(name, download=True,
                root='~/.insightface/models', **kwargs):
    """Build a FaceRecognition model, fetching pretrained params when
    `download` is True; otherwise the model has no params file."""
    if download:
        from .model_store import get_model_file
        params_path = get_model_file("arcface_%s" % name, root=root)
        return FaceRecognition(name, True, params_path)
    return FaceRecognition(name, False, None)
|
||||
|
||||
def arcface_r100_v1(**kwargs):
    """ArcFace ResNet-100 recognition model (pretrained, downloaded on demand)."""
    return get_arcface("r100_v1", download=True, **kwargs)
|
||||
|
||||
|
||||
def arcface_mfn_v1(**kwargs):
    """ArcFace MobileFaceNet recognition model (pretrained, downloaded on demand)."""
    return get_arcface("mfn_v1", download=True, **kwargs)
|
||||
|
||||
def arcface_outofreach_v1(**kwargs):
    """ArcFace model whose params are not downloadable (no model zoo entry)."""
    return get_arcface("outofreach_v1", download=False, **kwargs)
|
||||
|
||||
92
python-package/insightface/model_zoo/model_store.py
Normal file
92
python-package/insightface/model_zoo/model_store.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Model store which provides pretrained models."""
|
||||
from __future__ import print_function
|
||||
|
||||
__all__ = ['get_model_file']
|
||||
import os
|
||||
import zipfile
|
||||
import glob
|
||||
|
||||
from ..utils import download, check_sha1
|
||||
|
||||
# Known-good SHA-1 digests of each downloadable model, keyed by model name.
# An empty checksum effectively disables verification for that model, because
# check_sha1 compares only the common prefix of the two digest strings.
_model_sha1 = {name: checksum for checksum, name in [
    ('95be21b58e29e9c1237f229dae534bd854009ce0', 'arcface_r100_v1'),
    ('', 'arcface_mfn_v1'),
    ('39fd1e087a2a2ed70a154ac01fecaa86c315d01b', 'retinaface_r50_v1'),
    ('2c9de8116d1f448fd1d4661f90308faae34c990a', 'retinaface_mnet025_v1'),
]}

# Base URL of the online model zoo; each model is published as a zip archive
# at '<repo_url>models/<name>.zip'.
base_repo_url = 'http://insightface.ai/files/'
_url_format = '{repo_url}models/{file_name}.zip'
|
||||
|
||||
|
||||
def short_hash(name):
    """Return the first 8 hex digits of the recorded SHA-1 for *name*."""
    try:
        full_hash = _model_sha1[name]
    except KeyError:
        raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
    return full_hash[:8]
|
||||
|
||||
|
||||
def find_params_file(dir_path):
    """Return the lexicographically last .params file in *dir_path*, or None
    when the directory is missing or holds no params files."""
    if not os.path.exists(dir_path):
        return None
    candidates = sorted(glob.glob("%s/*.params" % dir_path))
    return candidates[-1] if candidates else None
|
||||
|
||||
def get_model_file(name, root=os.path.join('~', '.insightface', 'models')):
    r"""Return location for the pretrained on local file system.

    This function will download from online model zoo when model cannot be found or has mismatch.
    The root directory will be created if it doesn't exist.

    Parameters
    ----------
    name : str
        Name of the model.
    root : str, default '~/.insightface/models'
        Location for keeping the model parameters.

    Returns
    -------
    file_path
        Path to the requested pretrained model file.
    """

    file_name = name
    root = os.path.expanduser(root)
    # Each model gets its own subdirectory under root; the params file inside
    # it is located by suffix rather than an exact name.
    dir_path = os.path.join(root, name)
    file_path = find_params_file(dir_path)
    #file_path = os.path.join(root, file_name + '.params')
    sha1_hash = _model_sha1[name]
    if file_path is not None:
        # A cached file is reused only when its hash matches the registry.
        if check_sha1(file_path, sha1_hash):
            return file_path
        else:
            print('Mismatch in the content of model file detected. Downloading again.')
    else:
        print('Model file is not found. Downloading.')

    if not os.path.exists(root):
        os.makedirs(root)
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    # Fetch the zip archive from the model zoo, unpack it into the model's
    # directory, then discard the archive.
    zip_file_path = os.path.join(root, file_name + '.zip')
    repo_url = base_repo_url
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name=file_name),
             path=zip_file_path,
             overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(dir_path)
    os.remove(zip_file_path)
    file_path = find_params_file(dir_path)

    if check_sha1(file_path, sha1_hash):
        return file_path
    else:
        raise ValueError('Downloaded file has different hash. Please try again.')
|
||||
|
||||
|
||||
54
python-package/insightface/model_zoo/model_zoo.py
Normal file
54
python-package/insightface/model_zoo/model_zoo.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# pylint: disable=wildcard-import, unused-wildcard-import
|
||||
"""Model store which handles pretrained models from both
|
||||
mxnet.gluon.model_zoo.vision and gluoncv.models
|
||||
"""
|
||||
from .face_recognition import *
|
||||
from .face_detection import *
|
||||
#from .face_alignment import *
|
||||
|
||||
__all__ = ['get_model', 'get_model_list']
|
||||
|
||||
# Registry mapping lowercase model names to their factory functions.
# Commented-out entries exist in the code but are not yet published.
_models = {
    'arcface_r100_v1': arcface_r100_v1,
    #'arcface_mfn_v1': arcface_mfn_v1,
    #'arcface_outofreach_v1': arcface_outofreach_v1,
    'retinaface_r50_v1': retinaface_r50_v1,
    'retinaface_mnet025_v1': retinaface_mnet025_v1,
}
|
||||
|
||||
|
||||
def get_model(name, **kwargs):
    """Returns a pre-defined model by name

    Parameters
    ----------
    name : str
        Name of the model.
    root : str, default '~/.insightface/models'
        Location for keeping the model parameters.

    Returns
    -------
    Model
        The model.
    """
    name = name.lower()
    try:
        creator = _models[name]
    except KeyError:
        err_str = '"%s" is not among the following model list:\n\t' % (name)
        err_str += '%s' % ('\n\t'.join(sorted(_models.keys())))
        raise ValueError(err_str)
    return creator(**kwargs)
|
||||
|
||||
|
||||
def get_model_list():
    """Get the entire list of model names in model_zoo.

    Returns
    -------
    list of str
        Entire list of model names in model_zoo, sorted alphabetically.
    """
    return sorted(_models)
|
||||
|
||||
17
python-package/insightface/utils/__init__.py
Normal file
17
python-package/insightface/utils/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
#from . import bbox
|
||||
#from . import viz
|
||||
#from . import random
|
||||
#from . import metrics
|
||||
#from . import parallel
|
||||
|
||||
from .download import download, check_sha1
|
||||
from .filesystem import makedirs
|
||||
from .filesystem import try_import_dali
|
||||
#from .bbox import bbox_iou
|
||||
#from .block import recursive_visit, set_lr_mult, freeze_bn
|
||||
#from .lr_scheduler import LRSequential, LRScheduler
|
||||
#from .plot_history import TrainingHistory
|
||||
#from .export_helper import export_block
|
||||
#from .sync_loader_helper import split_data, split_and_load
|
||||
88
python-package/insightface/utils/download.py
Normal file
88
python-package/insightface/utils/download.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Download files with progress bar."""
|
||||
import os
|
||||
import hashlib
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits; a truncated digest is
        compared against the same-length prefix of the actual digest.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    digest = hashlib.sha1()
    with open(filename, 'rb') as stream:
        # Hash in 1 MiB chunks to keep memory bounded for large files.
        for chunk in iter(lambda: stream.read(1048576), b''):
            digest.update(chunk)

    actual = digest.hexdigest()
    prefix_len = min(len(actual), len(sha1_hash))
    return actual[0:prefix_len] == sha1_hash[0:prefix_len]
|
||||
|
||||
def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download an given URL

    Parameters
    ----------
    url : str
        URL to download
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.

    Returns
    -------
    str
        The file path of the downloaded file.
    """
    if path is None:
        # No destination given: save under the URL's basename in the CWD.
        fname = url.split('/')[-1]
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            # A directory destination keeps the URL's basename.
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path

    # Skip the download when the file already exists, unless overwrite is
    # requested or the existing file fails the hash check.
    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        print('Downloading %s from %s...'%(fname, url))
        # Stream the response so large files are written in 1 KB chunks.
        r = requests.get(url, stream=True)
        if r.status_code != 200:
            raise RuntimeError("Failed downloading url %s"%url)
        total_length = r.headers.get('content-length')
        with open(fname, 'wb') as f:
            if total_length is None: # no content length header
                for chunk in r.iter_content(chunk_size=1024):
                    if chunk: # filter out keep-alive new chunks
                        f.write(chunk)
            else:
                # Known size: show a KB-granularity progress bar.
                total_length = int(total_length)
                for chunk in tqdm(r.iter_content(chunk_size=1024),
                                  total=int(total_length / 1024. + 0.5),
                                  unit='KB', unit_scale=False, dynamic_ncols=True):
                    f.write(chunk)

        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. ' \
                              'The repo may be outdated or download may be incomplete. ' \
                              'If the "repo_url" is overridden, consider switching to ' \
                              'the default repo.'.format(fname))

    return fname
|
||||
135
python-package/insightface/utils/filesystem.py
Normal file
135
python-package/insightface/utils/filesystem.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Filesystem utility functions."""
|
||||
import os
|
||||
import errno
|
||||
|
||||
def makedirs(path):
    """Create directory recursively if not exists.
    Similar to `makedir -p`, you can skip checking existence before this function.

    Parameters
    ----------
    path : str
        Path of the desired dir
    """
    try:
        os.makedirs(path)
    except OSError as exc:
        # An already-existing path is fine; anything else is a real error.
        if exc.errno == errno.EEXIST:
            return
        raise
|
||||
|
||||
def try_import(package, message=None):
    """Try import specified package, with custom message support.

    Parameters
    ----------
    package : str
        The name of the targeting package.
    message : str, default is None
        If not None, this function will raise customized error message when import error is found.

    Returns
    -------
    module if found, raise ImportError otherwise
    """
    try:
        return __import__(package)
    except ImportError as err:
        # With a custom message, replace the error text; otherwise re-raise.
        if message:
            raise ImportError(message)
        raise err
|
||||
|
||||
def try_import_cv2():
    """Try import cv2 at runtime.

    Returns
    -------
    cv2 module if found. Raise ImportError otherwise
    """
    return try_import(
        'cv2',
        "cv2 is required, you can install by package manager, e.g. 'apt-get', \
or `pip install opencv-python --user` (note that this is unofficial PYPI package).")
|
||||
|
||||
def try_import_mmcv():
    """Try import mmcv at runtime.

    Returns
    -------
    mmcv module if found. Raise ImportError otherwise
    """
    return try_import(
        'mmcv',
        "mmcv is required, you can install by first `pip install Cython --user` \
and then `pip install mmcv --user` (note that this is unofficial PYPI package).")
|
||||
|
||||
def try_import_rarfile():
    """Try import rarfile at runtime.

    Returns
    -------
    rarfile module if found. Raise ImportError otherwise
    """
    return try_import(
        'rarfile',
        "rarfile is required, you can install by first `sudo apt-get install unrar` \
and then `pip install rarfile --user` (note that this is unofficial PYPI package).")
|
||||
|
||||
def import_try_install(package, extern_url=None):
    """Try import the specified package.
    If the package not installed, try use pip to install and import if success.

    Parameters
    ----------
    package : str
        The name of the package trying to import.
    extern_url : str or None, optional
        The external url if package is not hosted on PyPI.
        For example, you can install a package using:
        "pip install git+http://github.com/user/repo/tarball/master/egginfo=xxx".
        In this case, you can pass the url to the extern_url.

    Returns
    -------
    <class 'Module'>
        The imported python module.
    """
    try:
        return __import__(package)
    except ImportError:
        # pip moved its entry point to pip._internal in pip >= 10.
        try:
            from pip import main as pipmain
        except ImportError:
            from pip._internal import main as pipmain

        # trying to install package
        url = package if extern_url is None else extern_url
        pipmain(['install', '--user', url])  # will raise SystemExit Error if fails

        # trying to load again
        try:
            return __import__(package)
        except ImportError:
            # A --user install may land in a site dir that is not yet on
            # sys.path in this process; add it and retry once.
            import sys
            import site
            user_site = site.getusersitepackages()
            if user_site not in sys.path:
                sys.path.append(user_site)
            return __import__(package)
    # BUG FIX: the original had a trailing `return __import__(package)` here
    # that was unreachable (every path above already returns); removed.
|
||||
|
||||
def try_import_dali():
    """Try import NVIDIA DALI at runtime.

    Returns the real `nvidia.dali` module (with a `Pipeline` alias attached)
    when installed; otherwise a stand-in whose `Pipeline` raises on use.
    """
    try:
        dali = __import__('nvidia.dali', fromlist=['pipeline', 'ops', 'types'])
        dali.Pipeline = dali.pipeline.Pipeline
        return dali
    except ImportError:
        # Placeholder that fails loudly only when a Pipeline is constructed.
        class dali:
            class Pipeline:
                def __init__(self):
                    raise NotImplementedError(
                        "DALI not found, please check if you installed it correctly.")
        return dali
|
||||
61
python-package/setup.py
Normal file
61
python-package/setup.py
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env python
|
||||
import os
|
||||
import io
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
def read(*names, **kwargs):
    """Read and return the text of a file located relative to this script.

    Path components are joined onto this file's directory; `encoding`
    defaults to utf8.
    """
    target = os.path.join(os.path.dirname(__file__), *names)
    encoding = kwargs.get("encoding", "utf8")
    with io.open(target, encoding=encoding) as fp:
        return fp.read()
|
||||
|
||||
|
||||
def find_version(*file_paths):
    """Extract the `__version__ = '...'` string from the given source file,
    raising RuntimeError when no such assignment is found."""
    contents = read(*file_paths)
    match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
                      contents, re.M)
    if not match:
        raise RuntimeError("Unable to find version string.")
    return match.group(1)
|
||||
|
||||
# Build the long description: prefer README.md converted to reStructuredText
# via pypandoc (nicer PyPI rendering), fall back to raw markdown when
# pypandoc is missing or the conversion fails.
try:
    import pypandoc
    long_description = pypandoc.convert('README.md', 'rst')
except(IOError, ImportError):
    long_description = open('README.md').read()

# Single-source the package version from insightface/__init__.py.
VERSION = find_version('insightface', '__init__.py')

# Runtime dependencies installed alongside the package.
requirements = [
    'numpy',
    'tqdm',
    'requests',
    'matplotlib',
    'Pillow',
    'scipy',
    'opencv-python',
    'scikit-learn',
    'scikit-image',
    'easydict',
]

setup(
    # Metadata
    name='insightface',
    version=VERSION,
    author='InsightFace Contributors',
    url='https://github.com/deepinsight/insightface',
    description='InsightFace Toolkit',
    long_description=long_description,
    license='Apache-2.0',
    # Package info
    packages=find_packages(exclude=('docs', 'tests', 'scripts')),
    zip_safe=True,
    include_package_data=True,
    install_requires=requirements,
)
|
||||
|
||||
Reference in New Issue
Block a user