pip package

This commit is contained in:
nttstar
2019-08-29 23:23:35 +08:00
parent 159ab42490
commit 9fc1ca6e46
11 changed files with 984 additions and 0 deletions

3
python-package/README.md Normal file
View File

@@ -0,0 +1,3 @@
InsightFace.ai README

View File

@@ -0,0 +1,28 @@
# coding: utf-8
# pylint: disable=wrong-import-position
"""InsightFace: A Face Analysis Toolkit."""
from __future__ import absolute_import
# mxnet version check
#mx_version = '1.4.0'
try:
import mxnet as mx
#from distutils.version import LooseVersion
#if LooseVersion(mx.__version__) < LooseVersion(mx_version):
# msg = (
# "Legacy mxnet-mkl=={} detected, some new modules may not work properly. "
# "mxnet-mkl>={} is required. You can use pip to upgrade mxnet "
# "`pip install mxnet-mkl --pre --upgrade` "
# "or `pip install mxnet-cu90mkl --pre --upgrade`").format(mx.__version__, mx_version)
# raise ImportError(msg)
except ImportError:
raise ImportError(
"Unable to import dependency mxnet. "
"A quick tip is to install via `pip install mxnet-mkl/mxnet-cu90mkl --pre`. ")
__version__ = '0.1.2'
from . import model_zoo
#from . import utils
#from . import analysis

View File

@@ -0,0 +1 @@
from .model_zoo import get_model, get_model_list

View File

@@ -0,0 +1,422 @@
from __future__ import division
import mxnet as mx
import numpy as np
import mxnet.ndarray as nd
__all__ = ['FaceDetector',
'retinaface_r50_v1',
'retinaface_mnet025_v1',
'get_retinaface']
def _whctrs(anchor):
"""
Return width, height, x center, and y center for an anchor (window).
"""
w = anchor[2] - anchor[0] + 1
h = anchor[3] - anchor[1] + 1
x_ctr = anchor[0] + 0.5 * (w - 1)
y_ctr = anchor[1] + 0.5 * (h - 1)
return w, h, x_ctr, y_ctr
def _mkanchors(ws, hs, x_ctr, y_ctr):
"""
Given a vector of widths (ws) and heights (hs) around a center
(x_ctr, y_ctr), output a set of anchors (windows).
"""
ws = ws[:, np.newaxis]
hs = hs[:, np.newaxis]
anchors = np.hstack((x_ctr - 0.5 * (ws - 1),
y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1),
y_ctr + 0.5 * (hs - 1)))
return anchors
def _ratio_enum(anchor, ratios):
"""
Enumerate a set of anchors for each aspect ratio wrt an anchor.
"""
w, h, x_ctr, y_ctr = _whctrs(anchor)
size = w * h
size_ratios = size / ratios
ws = np.round(np.sqrt(size_ratios))
hs = np.round(ws * ratios)
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors
def _scale_enum(anchor, scales):
"""
Enumerate a set of anchors for each scale wrt an anchor.
"""
w, h, x_ctr, y_ctr = _whctrs(anchor)
ws = w * scales
hs = h * scales
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors
def anchors_plane(height, width, stride, base_anchors):
"""
Parameters
----------
height: height of plane
width: width of plane
stride: stride ot the original image
anchors_base: (A, 4) a base set of anchors
Returns
-------
all_anchors: (height, width, A, 4) ndarray of anchors spreading over the plane
"""
A = base_anchors.shape[0]
all_anchors = np.zeros((height, width, A, 4), dtype=np.float32)
for iw in range(width):
sw = iw * stride
for ih in range(height):
sh = ih * stride
for k in range(A):
all_anchors[ih, iw, k, 0] = base_anchors[k, 0] + sw
all_anchors[ih, iw, k, 1] = base_anchors[k, 1] + sh
all_anchors[ih, iw, k, 2] = base_anchors[k, 2] + sw
all_anchors[ih, iw, k, 3] = base_anchors[k, 3] + sh
return all_anchors
def generate_anchors(base_size=16, ratios=[0.5, 1, 2],
scales=2 ** np.arange(3, 6), stride=16):
"""
Generate anchor (reference) windows by enumerating aspect ratios X
scales wrt a reference (0, 0, 15, 15) window.
"""
base_anchor = np.array([1, 1, base_size, base_size]) - 1
ratio_anchors = _ratio_enum(base_anchor, ratios)
anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales)
for i in range(ratio_anchors.shape[0])])
return anchors
def generate_anchors_fpn(cfg):
"""
Generate anchor (reference) windows by enumerating aspect ratios X
scales wrt a reference (0, 0, 15, 15) window.
"""
RPN_FEAT_STRIDE = []
for k in cfg:
RPN_FEAT_STRIDE.append( int(k) )
RPN_FEAT_STRIDE = sorted(RPN_FEAT_STRIDE, reverse=True)
anchors = []
for k in RPN_FEAT_STRIDE:
v = cfg[str(k)]
bs = v['BASE_SIZE']
__ratios = np.array(v['RATIOS'])
__scales = np.array(v['SCALES'])
stride = int(k)
#print('anchors_fpn', bs, __ratios, __scales, file=sys.stderr)
r = generate_anchors(bs, __ratios, __scales, stride)
#print('anchors_fpn', r.shape, file=sys.stderr)
anchors.append(r)
return anchors
def clip_pad(tensor, pad_shape):
"""
Clip boxes of the pad area.
:param tensor: [n, c, H, W]
:param pad_shape: [h, w]
:return: [n, c, h, w]
"""
H, W = tensor.shape[2:]
h, w = pad_shape
if h < H or w < W:
tensor = tensor[:, :, :h, :w].copy()
return tensor
def bbox_pred(boxes, box_deltas):
"""
Transform the set of class-agnostic boxes into class-specific boxes
by applying the predicted offsets (box_deltas)
:param boxes: !important [N 4]
:param box_deltas: [N, 4 * num_classes]
:return: [N 4 * num_classes]
"""
if boxes.shape[0] == 0:
return np.zeros((0, box_deltas.shape[1]))
boxes = boxes.astype(np.float, copy=False)
widths = boxes[:, 2] - boxes[:, 0] + 1.0
heights = boxes[:, 3] - boxes[:, 1] + 1.0
ctr_x = boxes[:, 0] + 0.5 * (widths - 1.0)
ctr_y = boxes[:, 1] + 0.5 * (heights - 1.0)
dx = box_deltas[:, 0:1]
dy = box_deltas[:, 1:2]
dw = box_deltas[:, 2:3]
dh = box_deltas[:, 3:4]
pred_ctr_x = dx * widths[:, np.newaxis] + ctr_x[:, np.newaxis]
pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis]
pred_w = np.exp(dw) * widths[:, np.newaxis]
pred_h = np.exp(dh) * heights[:, np.newaxis]
pred_boxes = np.zeros(box_deltas.shape)
# x1
pred_boxes[:, 0:1] = pred_ctr_x - 0.5 * (pred_w - 1.0)
# y1
pred_boxes[:, 1:2] = pred_ctr_y - 0.5 * (pred_h - 1.0)
# x2
pred_boxes[:, 2:3] = pred_ctr_x + 0.5 * (pred_w - 1.0)
# y2
pred_boxes[:, 3:4] = pred_ctr_y + 0.5 * (pred_h - 1.0)
if box_deltas.shape[1]>4:
pred_boxes[:,4:] = box_deltas[:,4:]
return pred_boxes
def landmark_pred(boxes, landmark_deltas):
if boxes.shape[0] == 0:
return np.zeros((0, landmark_deltas.shape[1]))
boxes = boxes.astype(np.float, copy=False)
widths = boxes[:, 2] - boxes[:, 0] + 1.0
heights = boxes[:, 3] - boxes[:, 1] + 1.0
ctr_x = boxes[:, 0] + 0.5 * (widths - 1.0)
ctr_y = boxes[:, 1] + 0.5 * (heights - 1.0)
pred = landmark_deltas.copy()
for i in range(5):
pred[:,i,0] = landmark_deltas[:,i,0]*widths + ctr_x
pred[:,i,1] = landmark_deltas[:,i,1]*heights + ctr_y
return pred
class FaceDetector:
def __init__(self, param_file, rac):
self.param_file = param_file
self.rac = rac
self.default_image_size = (480, 640)
def prepare(self, ctx_id, nms=0.4, fix_image_size=None):
pos = self.param_file.rfind('-')
prefix = self.param_file[0:pos]
pos2 = self.param_file.rfind('.')
epoch = int(self.param_file[pos+1:pos2])
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
if ctx_id>=0:
ctx = mx.gpu(ctx_id)
else:
ctx = mx.cpu()
model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)
if fix_image_size is not None:
data_shape = (1,3)+fix_image_size
else:
data_shape = (1,3)+self.default_image_size
model.bind(data_shapes=[('data', data_shape)])
model.set_params(arg_params, aux_params)
#warmup
data = mx.nd.zeros(shape=data_shape)
db = mx.io.DataBatch(data=(data,))
model.forward(db, is_train=False)
out = model.get_outputs()[0].asnumpy()
self.model = model
self.nms_threshold = nms
_ratio = (1.,)
fmc = 3
if self.rac=='net3':
_ratio = (1.,)
elif network=='net5': #retinaface
fmc = 5
else:
assert False, 'rac setting error %s'%self.rac
if fmc==3:
self._feat_stride_fpn = [32, 16, 8]
self.anchor_cfg = {
'32': {'SCALES': (32,16), 'BASE_SIZE': 16, 'RATIOS': _ratio, 'ALLOWED_BORDER': 9999},
'16': {'SCALES': (8,4), 'BASE_SIZE': 16, 'RATIOS': _ratio, 'ALLOWED_BORDER': 9999},
'8': {'SCALES': (2,1), 'BASE_SIZE': 16, 'RATIOS': _ratio, 'ALLOWED_BORDER': 9999},
}
elif fmc==5:
self._feat_stride_fpn = [64, 32, 16, 8, 4]
self.anchor_cfg = {}
_ass = 2.0**(1.0/3)
_basescale = 1.0
for _stride in [4, 8, 16, 32, 64]:
key = str(_stride)
value = {'BASE_SIZE': 16, 'RATIOS': _ratio, 'ALLOWED_BORDER': 9999}
scales = []
for _ in range(3):
scales.append(_basescale)
_basescale *= _ass
value['SCALES'] = tuple(scales)
self.anchor_cfg[key] = value
print(self._feat_stride_fpn, self.anchor_cfg)
self.use_landmarks = False
if len(sym)//len(self._feat_stride_fpn)==3:
self.use_landmarks = True
print('use_landmarks', self.use_landmarks)
self.fpn_keys = []
for s in self._feat_stride_fpn:
self.fpn_keys.append('stride%s'%s)
self._anchors_fpn = dict(zip(self.fpn_keys, generate_anchors_fpn(cfg=self.anchor_cfg)))
for k in self._anchors_fpn:
v = self._anchors_fpn[k].astype(np.float32)
self._anchors_fpn[k] = v
self.anchor_plane_cache = {}
if fix_image_size is None:
self.anchor_plane_cache = None
self._num_anchors = dict(zip(self.fpn_keys, [anchors.shape[0] for anchors in self._anchors_fpn.values()]))
def detect(self, img, threshold=0.5, scale=1.0):
proposals_list = []
scores_list = []
landmarks_list = []
if scale==1.0:
im = img
else:
im = cv2.resize(img, None, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
im_info = [im.shape[0], im.shape[1]]
im_tensor = np.zeros((1, 3, im.shape[0], im.shape[1]))
for i in range(3):
im_tensor[0, i, :, :] = im[:, :, 2 - i]
data = nd.array(im_tensor)
db = mx.io.DataBatch(data=(data,), provide_data=[('data', data.shape)])
self.model.forward(db, is_train=False)
net_out = self.model.get_outputs()
for _idx,s in enumerate(self._feat_stride_fpn):
_key = 'stride%s'%s
stride = int(s)
if self.use_landmarks:
idx = _idx*3
else:
idx = _idx*2
scores = net_out[idx].asnumpy()
scores = scores[:, self._num_anchors['stride%s'%s]:, :, :]
idx+=1
bbox_deltas = net_out[idx].asnumpy()
height, width = bbox_deltas.shape[2], bbox_deltas.shape[3]
A = self._num_anchors['stride%s'%s]
K = height * width
if self.anchor_plane_cache is not None:
key = (height, width, stride)
if key in self.anchor_plane_cache:
anchors = self.anchor_plane_cache[key]
else:
anchors_fpn = self._anchors_fpn['stride%s'%s]
anchors = anchors_plane(height, width, stride, anchors_fpn)
anchors = anchors.reshape((K * A, 4))
self.anchor_plane_cache[key] = anchors
else:
anchors_fpn = self._anchors_fpn['stride%s'%s]
anchors = anchors_plane(height, width, stride, anchors_fpn)
anchors = anchors.reshape((K * A, 4))
scores = clip_pad(scores, (height, width))
scores = scores.transpose((0, 2, 3, 1)).reshape((-1, 1))
bbox_deltas = clip_pad(bbox_deltas, (height, width))
bbox_deltas = bbox_deltas.transpose((0, 2, 3, 1))
bbox_pred_len = bbox_deltas.shape[3]//A
bbox_deltas = bbox_deltas.reshape((-1, bbox_pred_len))
proposals = bbox_pred(anchors, bbox_deltas)
#proposals = clip_boxes(proposals, im_info[:2])
scores_ravel = scores.ravel()
order = np.where(scores_ravel>=threshold)[0]
proposals = proposals[order, :]
scores = scores[order]
proposals[:,0:4] /= scale
proposals_list.append(proposals)
scores_list.append(scores)
if self.use_landmarks:
idx+=1
landmark_deltas = net_out[idx].asnumpy()
landmark_deltas = clip_pad(landmark_deltas, (height, width))
landmark_pred_len = landmark_deltas.shape[1]//A
landmark_deltas = landmark_deltas.transpose((0, 2, 3, 1)).reshape((-1, 5, landmark_pred_len//5))
#print(landmark_deltas.shape, landmark_deltas)
landmarks = landmark_pred(anchors, landmark_deltas)
landmarks = landmarks[order, :]
landmarks[:,:,0:2] /= scale
landmarks_list.append(landmarks)
proposals = np.vstack(proposals_list)
landmarks = None
if proposals.shape[0]==0:
if self.use_landmarks:
landmarks = np.zeros( (0,5,2) )
return np.zeros( (0,5) ), landmarks
scores = np.vstack(scores_list)
scores_ravel = scores.ravel()
order = scores_ravel.argsort()[::-1]
proposals = proposals[order, :]
scores = scores[order]
if self.use_landmarks:
landmarks = np.vstack(landmarks_list)
landmarks = landmarks[order].astype(np.float32, copy=False)
pre_det = np.hstack((proposals[:,0:4], scores)).astype(np.float32, copy=False)
keep = self.nms(pre_det)
det = np.hstack( (pre_det, proposals[:,4:]) )
det = det[keep, :]
if self.use_landmarks:
landmarks = landmarks[keep]
return det, landmarks
def nms(self, dets):
thresh = self.nms_threshold
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
return keep
def get_retinaface(name, rac='net3',
root='~/.insightface/models', **kwargs):
from .model_store import get_model_file
_file = get_model_file("retinaface_%s"%name, root=root)
return FaceDetector(_file, rac)
def retinaface_r50_v1(**kwargs):
return get_retinaface("r50_v1", rac='net3', **kwargs)
def retinaface_mnet025_v1(**kwargs):
return get_retinaface("mnet025_v1", rac='net3', **kwargs)

View File

@@ -0,0 +1,83 @@
from __future__ import division
import mxnet as mx
import numpy as np
import cv2
__all__ = ['FaceRecognition',
'arcface_r100_v1', 'arcface_outofreach_v1', 'arcface_mfn_v1',
'get_arcface']
class FaceRecognition:
def __init__(self, name, download, param_file):
self.name = name
self.download = download
self.param_file = param_file
self.image_size = (112, 112)
if download:
assert param_file
def prepare(self, ctx_id):
if self.param_file:
pos = self.param_file.rfind('-')
prefix = self.param_file[0:pos]
pos2 = self.param_file.rfind('.')
epoch = int(self.param_file[pos+1:pos2])
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
all_layers = sym.get_internals()
sym = all_layers['fc1_output']
if ctx_id>=0:
ctx = mx.gpu(ctx_id)
else:
ctx = mx.cpu()
model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)
data_shape = (1,3)+self.image_size
model.bind(data_shapes=[('data', data_shape)])
model.set_params(arg_params, aux_params)
#warmup
data = mx.nd.zeros(shape=data_shape)
db = mx.io.DataBatch(data=(data,))
model.forward(db, is_train=False)
embedding = model.get_outputs()[0].asnumpy()
self.model = model
else:
pass
def get_embedding(self, img):
assert self.param_file and self.model
assert img.shape[2]==3 and img.shape[0:2]==self.image_size
data = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
data = np.transpose(data, (2,0,1))
data = np.expand_dims(data, axis=0)
data = mx.nd.array(data)
db = mx.io.DataBatch(data=(data,))
self.model.forward(db, is_train=False)
embedding = self.model.get_outputs()[0].asnumpy()
return embedding
def compute_sim(self, img1, img2):
emb1 = self.get_embedding(img1).flatten()
emb2 = self.get_embedding(img2).flatten()
from numpy.linalg import norm
sim = np.dot(emb1, emb2)/(norm(emb1)*norm(emb2))
return sim
def get_arcface(name, download=True,
root='~/.insightface/models', **kwargs):
if not download:
return FaceRecognition(name, False, None)
else:
from .model_store import get_model_file
_file = get_model_file("arcface_%s"%name, root=root)
return FaceRecognition(name, True, _file)
def arcface_r100_v1(**kwargs):
return get_arcface("r100_v1", download=True, **kwargs)
def arcface_mfn_v1(**kwargs):
return get_arcface("mfn_v1", download=True, **kwargs)
def arcface_outofreach_v1(**kwargs):
return get_arcface("outofreach_v1", download=False, **kwargs)

View File

@@ -0,0 +1,92 @@
"""Model store which provides pretrained models."""
from __future__ import print_function
__all__ = ['get_model_file']
import os
import zipfile
import glob
from ..utils import download, check_sha1
_model_sha1 = {name: checksum for checksum, name in [
('95be21b58e29e9c1237f229dae534bd854009ce0', 'arcface_r100_v1'),
('', 'arcface_mfn_v1'),
('39fd1e087a2a2ed70a154ac01fecaa86c315d01b', 'retinaface_r50_v1'),
('2c9de8116d1f448fd1d4661f90308faae34c990a', 'retinaface_mnet025_v1'),
]}
base_repo_url = 'http://insightface.ai/files/'
_url_format = '{repo_url}models/{file_name}.zip'
def short_hash(name):
if name not in _model_sha1:
raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
return _model_sha1[name][:8]
def find_params_file(dir_path):
if not os.path.exists(dir_path):
return None
paths = glob.glob("%s/*.params"%dir_path)
if len(paths)==0:
return None
paths = sorted(paths)
return paths[-1]
def get_model_file(name, root=os.path.join('~', '.insightface', 'models')):
r"""Return location for the pretrained on local file system.
This function will download from online model zoo when model cannot be found or has mismatch.
The root directory will be created if it doesn't exist.
Parameters
----------
name : str
Name of the model.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
Returns
-------
file_path
Path to the requested pretrained model file.
"""
file_name = name
root = os.path.expanduser(root)
dir_path = os.path.join(root, name)
file_path = find_params_file(dir_path)
#file_path = os.path.join(root, file_name + '.params')
sha1_hash = _model_sha1[name]
if file_path is not None:
if check_sha1(file_path, sha1_hash):
return file_path
else:
print('Mismatch in the content of model file detected. Downloading again.')
else:
print('Model file is not found. Downloading.')
if not os.path.exists(root):
os.makedirs(root)
if not os.path.exists(dir_path):
os.makedirs(dir_path)
zip_file_path = os.path.join(root, file_name + '.zip')
repo_url = base_repo_url
if repo_url[-1] != '/':
repo_url = repo_url + '/'
download(_url_format.format(repo_url=repo_url, file_name=file_name),
path=zip_file_path,
overwrite=True)
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(dir_path)
os.remove(zip_file_path)
file_path = find_params_file(dir_path)
if check_sha1(file_path, sha1_hash):
return file_path
else:
raise ValueError('Downloaded file has different hash. Please try again.')

View File

@@ -0,0 +1,54 @@
# pylint: disable=wildcard-import, unused-wildcard-import
"""Model store which handles pretrained models from both
mxnet.gluon.model_zoo.vision and gluoncv.models
"""
from .face_recognition import *
from .face_detection import *
#from .face_alignment import *
__all__ = ['get_model', 'get_model_list']
_models = {
'arcface_r100_v1': arcface_r100_v1,
#'arcface_mfn_v1': arcface_mfn_v1,
#'arcface_outofreach_v1': arcface_outofreach_v1,
'retinaface_r50_v1': retinaface_r50_v1,
'retinaface_mnet025_v1': retinaface_mnet025_v1,
}
def get_model(name, **kwargs):
"""Returns a pre-defined model by name
Parameters
----------
name : str
Name of the model.
root : str, default '~/.insightface/models'
Location for keeping the model parameters.
Returns
-------
Model
The model.
"""
name = name.lower()
if name not in _models:
err_str = '"%s" is not among the following model list:\n\t' % (name)
err_str += '%s' % ('\n\t'.join(sorted(_models.keys())))
raise ValueError(err_str)
net = _models[name](**kwargs)
return net
def get_model_list():
"""Get the entire list of model names in model_zoo.
Returns
-------
list of str
Entire list of model names in model_zoo.
"""
return sorted(_models.keys())

View File

@@ -0,0 +1,17 @@
from __future__ import absolute_import
#from . import bbox
#from . import viz
#from . import random
#from . import metrics
#from . import parallel
from .download import download, check_sha1
from .filesystem import makedirs
from .filesystem import try_import_dali
#from .bbox import bbox_iou
#from .block import recursive_visit, set_lr_mult, freeze_bn
#from .lr_scheduler import LRSequential, LRScheduler
#from .plot_history import TrainingHistory
#from .export_helper import export_block
#from .sync_loader_helper import split_data, split_and_load

View File

@@ -0,0 +1,88 @@
"""Download files with progress bar."""
import os
import hashlib
import requests
from tqdm import tqdm
def check_sha1(filename, sha1_hash):
"""Check whether the sha1 hash of the file content matches the expected hash.
Parameters
----------
filename : str
Path to the file.
sha1_hash : str
Expected sha1 hash in hexadecimal digits.
Returns
-------
bool
Whether the file content matches the expected hash.
"""
sha1 = hashlib.sha1()
with open(filename, 'rb') as f:
while True:
data = f.read(1048576)
if not data:
break
sha1.update(data)
sha1_file = sha1.hexdigest()
l = min(len(sha1_file), len(sha1_hash))
return sha1.hexdigest()[0:l] == sha1_hash[0:l]
def download(url, path=None, overwrite=False, sha1_hash=None):
"""Download an given URL
Parameters
----------
url : str
URL to download
path : str, optional
Destination path to store downloaded file. By default stores to the
current directory with same name as in url.
overwrite : bool, optional
Whether to overwrite destination file if already exists.
sha1_hash : str, optional
Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
but doesn't match.
Returns
-------
str
The file path of the downloaded file.
"""
if path is None:
fname = url.split('/')[-1]
else:
path = os.path.expanduser(path)
if os.path.isdir(path):
fname = os.path.join(path, url.split('/')[-1])
else:
fname = path
if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
os.makedirs(dirname)
print('Downloading %s from %s...'%(fname, url))
r = requests.get(url, stream=True)
if r.status_code != 200:
raise RuntimeError("Failed downloading url %s"%url)
total_length = r.headers.get('content-length')
with open(fname, 'wb') as f:
if total_length is None: # no content length header
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
else:
total_length = int(total_length)
for chunk in tqdm(r.iter_content(chunk_size=1024),
total=int(total_length / 1024. + 0.5),
unit='KB', unit_scale=False, dynamic_ncols=True):
f.write(chunk)
if sha1_hash and not check_sha1(fname, sha1_hash):
raise UserWarning('File {} is downloaded but the content hash does not match. ' \
'The repo may be outdated or download may be incomplete. ' \
'If the "repo_url" is overridden, consider switching to ' \
'the default repo.'.format(fname))
return fname

View File

@@ -0,0 +1,135 @@
"""Filesystem utility functions."""
import os
import errno
def makedirs(path):
"""Create directory recursively if not exists.
Similar to `makedir -p`, you can skip checking existence before this function.
Parameters
----------
path : str
Path of the desired dir
"""
try:
os.makedirs(path)
except OSError as exc:
if exc.errno != errno.EEXIST:
raise
def try_import(package, message=None):
"""Try import specified package, with custom message support.
Parameters
----------
package : str
The name of the targeting package.
message : str, default is None
If not None, this function will raise customized error message when import error is found.
Returns
-------
module if found, raise ImportError otherwise
"""
try:
return __import__(package)
except ImportError as e:
if not message:
raise e
raise ImportError(message)
def try_import_cv2():
"""Try import cv2 at runtime.
Returns
-------
cv2 module if found. Raise ImportError otherwise
"""
msg = "cv2 is required, you can install by package manager, e.g. 'apt-get', \
or `pip install opencv-python --user` (note that this is unofficial PYPI package)."
return try_import('cv2', msg)
def try_import_mmcv():
"""Try import mmcv at runtime.
Returns
-------
mmcv module if found. Raise ImportError otherwise
"""
msg = "mmcv is required, you can install by first `pip install Cython --user` \
and then `pip install mmcv --user` (note that this is unofficial PYPI package)."
return try_import('mmcv', msg)
def try_import_rarfile():
"""Try import rarfile at runtime.
Returns
-------
rarfile module if found. Raise ImportError otherwise
"""
msg = "rarfile is required, you can install by first `sudo apt-get install unrar` \
and then `pip install rarfile --user` (note that this is unofficial PYPI package)."
return try_import('rarfile', msg)
def import_try_install(package, extern_url=None):
"""Try import the specified package.
If the package not installed, try use pip to install and import if success.
Parameters
----------
package : str
The name of the package trying to import.
extern_url : str or None, optional
The external url if package is not hosted on PyPI.
For example, you can install a package using:
"pip install git+http://github.com/user/repo/tarball/master/egginfo=xxx".
In this case, you can pass the url to the extern_url.
Returns
-------
<class 'Module'>
The imported python module.
"""
try:
return __import__(package)
except ImportError:
try:
from pip import main as pipmain
except ImportError:
from pip._internal import main as pipmain
# trying to install package
url = package if extern_url is None else extern_url
pipmain(['install', '--user', url]) # will raise SystemExit Error if fails
# trying to load again
try:
return __import__(package)
except ImportError:
import sys
import site
user_site = site.getusersitepackages()
if user_site not in sys.path:
sys.path.append(user_site)
return __import__(package)
return __import__(package)
def try_import_dali():
"""Try import NVIDIA DALI at runtime.
"""
try:
dali = __import__('nvidia.dali', fromlist=['pipeline', 'ops', 'types'])
dali.Pipeline = dali.pipeline.Pipeline
except ImportError:
class dali:
class Pipeline:
def __init__(self):
raise NotImplementedError(
"DALI not found, please check if you installed it correctly.")
return dali

61
python-package/setup.py Normal file
View File

@@ -0,0 +1,61 @@
#!/usr/bin/env python
import os
import io
import re
import shutil
import sys
from setuptools import setup, find_packages
def read(*names, **kwargs):
with io.open(
os.path.join(os.path.dirname(__file__), *names),
encoding=kwargs.get("encoding", "utf8")
) as fp:
return fp.read()
def find_version(*file_paths):
version_file = read(*file_paths)
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
version_file, re.M)
if version_match:
return version_match.group(1)
raise RuntimeError("Unable to find version string.")
try:
import pypandoc
long_description = pypandoc.convert('README.md', 'rst')
except(IOError, ImportError):
long_description = open('README.md').read()
VERSION = find_version('insightface', '__init__.py')
requirements = [
'numpy',
'tqdm',
'requests',
'matplotlib',
'Pillow',
'scipy',
'opencv-python',
'scikit-learn',
'scikit-image',
'easydict',
]
setup(
# Metadata
name='insightface',
version=VERSION,
author='InsightFace Contributors',
url='https://github.com/deepinsight/insightface',
description='InsightFace Toolkit',
long_description=long_description,
license='Apache-2.0',
# Package info
packages=find_packages(exclude=('docs', 'tests', 'scripts')),
zip_safe=True,
include_package_data=True,
install_requires=requirements,
)