mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
python package ver 0.6 update
This commit is contained in:
@@ -12,6 +12,19 @@ The code of InsightFace Python Library is released under the MIT License. There
|
||||
pip install -U insightface
|
||||
```
|
||||
|
||||
## Change Log
|
||||
|
||||
### [0.6] - 2022-01-29
|
||||
|
||||
#### Added
|
||||
|
||||
- Add pose estimation in face-analysis app.
|
||||
|
||||
#### Changed
|
||||
|
||||
- Change model automated downloading url, to ucloud.
|
||||
|
||||
|
||||
## Quick Example
|
||||
|
||||
```
|
||||
|
||||
@@ -11,7 +11,7 @@ except ImportError:
|
||||
"Unable to import dependency onnxruntime. "
|
||||
)
|
||||
|
||||
__version__ = '0.5'
|
||||
__version__ = '0.6'
|
||||
|
||||
from . import model_zoo
|
||||
from . import utils
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
from .image import get_image
|
||||
from .pickle_object import get_object
|
||||
|
||||
BIN
python-package/insightface/data/objects/meanshape_68.pkl
Normal file
BIN
python-package/insightface/data/objects/meanshape_68.pkl
Normal file
Binary file not shown.
17
python-package/insightface/data/pickle_object.py
Normal file
17
python-package/insightface/data/pickle_object.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import cv2
|
||||
import os
|
||||
import os.path as osp
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
|
||||
def get_object(name):
|
||||
objects_dir = osp.join(Path(__file__).parent.absolute(), 'objects')
|
||||
if not name.endswith('.pkl'):
|
||||
name = name+".pkl"
|
||||
filepath = osp.join(objects_dir, name)
|
||||
if not osp.exists(filepath):
|
||||
return None
|
||||
with open(filepath, 'rb') as f:
|
||||
obj = pickle.load(f)
|
||||
return obj
|
||||
|
||||
@@ -10,6 +10,8 @@ import cv2
|
||||
import onnx
|
||||
import onnxruntime
|
||||
from ..utils import face_align
|
||||
from ..utils import transform
|
||||
from ..data import get_object
|
||||
|
||||
__all__ = [
|
||||
'Landmark',
|
||||
@@ -59,10 +61,13 @@ class Landmark:
|
||||
self.output_names = output_names
|
||||
assert len(self.output_names)==1
|
||||
output_shape = outputs[0].shape
|
||||
self.require_pose = False
|
||||
#print('init output_shape:', output_shape)
|
||||
if output_shape[1]==3309:
|
||||
self.lmk_dim = 3
|
||||
self.lmk_num = 68
|
||||
self.mean_lmk = get_object('meanshape_68.pkl')
|
||||
self.require_pose = True
|
||||
else:
|
||||
self.lmk_dim = 2
|
||||
self.lmk_num = output_shape[1]//self.lmk_dim
|
||||
@@ -98,6 +103,12 @@ class Landmark:
|
||||
IM = cv2.invertAffineTransform(M)
|
||||
pred = face_align.trans_points(pred, IM)
|
||||
face[self.taskname] = pred
|
||||
if self.require_pose:
|
||||
P = transform.estimate_affine_matrix_3d23d(self.mean_lmk, pred)
|
||||
s, R, t = transform.P2sRt(P)
|
||||
rx, ry, rz = transform.matrix2angle(R)
|
||||
pose = np.array( [rx, ry, rz], dtype=np.float32 )
|
||||
face['pose'] = pose #pitch, yaw, roll
|
||||
return pred
|
||||
|
||||
|
||||
|
||||
@@ -18,31 +18,18 @@ from ..utils import download_onnx
|
||||
__all__ = ['get_model']
|
||||
|
||||
|
||||
def init_session(model_path, **kwargs):
|
||||
sess = onnxruntime.InferenceSession(model_path,**kwargs)
|
||||
return sess
|
||||
|
||||
class PickableInferenceSession:
|
||||
class PickableInferenceSession(onnxruntime.InferenceSession):
|
||||
# This is a wrapper to make the current InferenceSession class pickable.
|
||||
def __init__(self, model_path, **kwargs):
|
||||
super().__init__(model_path, **kwargs)
|
||||
self.model_path = model_path
|
||||
self.sess = init_session(self.model_path, **kwargs)
|
||||
|
||||
def run(self, *args):
|
||||
return self.sess.run(*args)
|
||||
|
||||
def __getstate__(self):
|
||||
return {'model_path': self.model_path}
|
||||
|
||||
def __setstate__(self, values):
|
||||
self.model_path = values['model_path']
|
||||
self.sess = init_session(self.model_path)
|
||||
|
||||
def get_inputs(self):
|
||||
return self.sess.get_inputs()
|
||||
|
||||
def get_outputs(self):
|
||||
return self.sess.get_outputs()
|
||||
model_path = values['model_path']
|
||||
self.__init__(model_path)
|
||||
|
||||
class ModelRouter:
|
||||
def __init__(self, onnx_file):
|
||||
|
||||
@@ -4,7 +4,8 @@ import os.path as osp
|
||||
import zipfile
|
||||
from .download import download_file
|
||||
|
||||
BASE_REPO_URL='http://storage.insightface.ai/files'
|
||||
#BASE_REPO_URL='http://storage.insightface.ai/files'
|
||||
BASE_REPO_URL='http://insightface.cn-sh2.ufileos.com'
|
||||
|
||||
def download(sub_dir, name, force=False, root='~/.insightface'):
|
||||
_root = os.path.expanduser(root)
|
||||
@@ -13,7 +14,8 @@ def download(sub_dir, name, force=False, root='~/.insightface'):
|
||||
return dir_path
|
||||
print('download_path:', dir_path)
|
||||
zip_file_path = os.path.join(_root, sub_dir, name + '.zip')
|
||||
model_url = "%s/%s/%s.zip"%(BASE_REPO_URL, sub_dir, name)
|
||||
#model_url = "%s/%s/%s.zip"%(BASE_REPO_URL, sub_dir, name)
|
||||
model_url = "%s/%s.zip"%(BASE_REPO_URL, name)
|
||||
download_file(model_url,
|
||||
path=zip_file_path,
|
||||
overwrite=True)
|
||||
|
||||
116
python-package/insightface/utils/transform.py
Normal file
116
python-package/insightface/utils/transform.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import cv2
|
||||
import math
|
||||
import numpy as np
|
||||
from skimage import transform as trans
|
||||
|
||||
|
||||
def transform(data, center, output_size, scale, rotation):
|
||||
scale_ratio = scale
|
||||
rot = float(rotation) * np.pi / 180.0
|
||||
#translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
|
||||
t1 = trans.SimilarityTransform(scale=scale_ratio)
|
||||
cx = center[0] * scale_ratio
|
||||
cy = center[1] * scale_ratio
|
||||
t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
|
||||
t3 = trans.SimilarityTransform(rotation=rot)
|
||||
t4 = trans.SimilarityTransform(translation=(output_size / 2,
|
||||
output_size / 2))
|
||||
t = t1 + t2 + t3 + t4
|
||||
M = t.params[0:2]
|
||||
cropped = cv2.warpAffine(data,
|
||||
M, (output_size, output_size),
|
||||
borderValue=0.0)
|
||||
return cropped, M
|
||||
|
||||
|
||||
def trans_points2d(pts, M):
|
||||
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
|
||||
for i in range(pts.shape[0]):
|
||||
pt = pts[i]
|
||||
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
|
||||
new_pt = np.dot(M, new_pt)
|
||||
#print('new_pt', new_pt.shape, new_pt)
|
||||
new_pts[i] = new_pt[0:2]
|
||||
|
||||
return new_pts
|
||||
|
||||
|
||||
def trans_points3d(pts, M):
|
||||
scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
|
||||
#print(scale)
|
||||
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
|
||||
for i in range(pts.shape[0]):
|
||||
pt = pts[i]
|
||||
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
|
||||
new_pt = np.dot(M, new_pt)
|
||||
#print('new_pt', new_pt.shape, new_pt)
|
||||
new_pts[i][0:2] = new_pt[0:2]
|
||||
new_pts[i][2] = pts[i][2] * scale
|
||||
|
||||
return new_pts
|
||||
|
||||
|
||||
def trans_points(pts, M):
|
||||
if pts.shape[1] == 2:
|
||||
return trans_points2d(pts, M)
|
||||
else:
|
||||
return trans_points3d(pts, M)
|
||||
|
||||
def estimate_affine_matrix_3d23d(X, Y):
|
||||
''' Using least-squares solution
|
||||
Args:
|
||||
X: [n, 3]. 3d points(fixed)
|
||||
Y: [n, 3]. corresponding 3d points(moving). Y = PX
|
||||
Returns:
|
||||
P_Affine: (3, 4). Affine camera matrix (the third row is [0, 0, 0, 1]).
|
||||
'''
|
||||
X_homo = np.hstack((X, np.ones([X.shape[0],1]))) #n x 4
|
||||
P = np.linalg.lstsq(X_homo, Y)[0].T # Affine matrix. 3 x 4
|
||||
return P
|
||||
|
||||
def P2sRt(P):
|
||||
''' decompositing camera matrix P
|
||||
Args:
|
||||
P: (3, 4). Affine Camera Matrix.
|
||||
Returns:
|
||||
s: scale factor.
|
||||
R: (3, 3). rotation matrix.
|
||||
t: (3,). translation.
|
||||
'''
|
||||
t = P[:, 3]
|
||||
R1 = P[0:1, :3]
|
||||
R2 = P[1:2, :3]
|
||||
s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2.0
|
||||
r1 = R1/np.linalg.norm(R1)
|
||||
r2 = R2/np.linalg.norm(R2)
|
||||
r3 = np.cross(r1, r2)
|
||||
|
||||
R = np.concatenate((r1, r2, r3), 0)
|
||||
return s, R, t
|
||||
|
||||
def matrix2angle(R):
|
||||
''' get three Euler angles from Rotation Matrix
|
||||
Args:
|
||||
R: (3,3). rotation matrix
|
||||
Returns:
|
||||
x: pitch
|
||||
y: yaw
|
||||
z: roll
|
||||
'''
|
||||
sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0])
|
||||
|
||||
singular = sy < 1e-6
|
||||
|
||||
if not singular :
|
||||
x = math.atan2(R[2,1] , R[2,2])
|
||||
y = math.atan2(-R[2,0], sy)
|
||||
z = math.atan2(R[1,0], R[0,0])
|
||||
else :
|
||||
x = math.atan2(-R[1,2], R[1,1])
|
||||
y = math.atan2(-R[2,0], sy)
|
||||
z = 0
|
||||
|
||||
# rx, ry, rz = np.rad2deg(x), np.rad2deg(y), np.rad2deg(z)
|
||||
rx, ry, rz = x*180/np.pi, y*180/np.pi, z*180/np.pi
|
||||
return rx, ry, rz
|
||||
|
||||
@@ -27,8 +27,10 @@ def find_version(*file_paths):
|
||||
pypandoc_enabled = True
|
||||
try:
|
||||
import pypandoc
|
||||
print('pandoc enabled')
|
||||
long_description = pypandoc.convert('README.md', 'rst')
|
||||
except (IOError, ImportError, ModuleNotFoundError):
|
||||
print('WARNING: pandoc not enabled')
|
||||
long_description = open('README.md').read()
|
||||
pypandoc_enabled = False
|
||||
|
||||
@@ -64,7 +66,10 @@ data_mesh = list(glob.glob('insightface/thirdparty/face3d/mesh/cython/*.h'))
|
||||
data_mesh += list(glob.glob('insightface/thirdparty/face3d/mesh/cython/*.c'))
|
||||
data_mesh += list(glob.glob('insightface/thirdparty/face3d/mesh/cython/*.py*'))
|
||||
|
||||
data_objects = list(glob.glob('insightface/data/objects/*.pkl'))
|
||||
|
||||
data_files = [ ('insightface/data/images', data_images) ]
|
||||
data_files += [ ('insightface/data/objects', data_objects) ]
|
||||
data_files += [ ('insightface/thirdparty/face3d/mesh/cython', data_mesh) ]
|
||||
|
||||
ext_modules=cythonize(extensions)
|
||||
|
||||
Reference in New Issue
Block a user