python package ver 0.6 update

This commit is contained in:
Jia Guo
2022-01-29 19:54:55 +08:00
parent dc632c254e
commit e7816a226d
10 changed files with 172 additions and 20 deletions

View File

@@ -12,6 +12,19 @@ The code of InsightFace Python Library is released under the MIT License. There
pip install -U insightface
```
## Change Log
### [0.6] - 2022-01-29
#### Added
- Add pose estimation in face-analysis app.
#### Changed
- Change model automated downloading url, to ucloud.
## Quick Example
```

View File

@@ -11,7 +11,7 @@ except ImportError:
"Unable to import dependency onnxruntime. "
)
__version__ = '0.5'
__version__ = '0.6'
from . import model_zoo
from . import utils

View File

@@ -1 +1,2 @@
from .image import get_image
from .pickle_object import get_object

View File

@@ -0,0 +1,17 @@
import cv2
import os
import os.path as osp
from pathlib import Path
import pickle
def get_object(name):
objects_dir = osp.join(Path(__file__).parent.absolute(), 'objects')
if not name.endswith('.pkl'):
name = name+".pkl"
filepath = osp.join(objects_dir, name)
if not osp.exists(filepath):
return None
with open(filepath, 'rb') as f:
obj = pickle.load(f)
return obj

View File

@@ -10,6 +10,8 @@ import cv2
import onnx
import onnxruntime
from ..utils import face_align
from ..utils import transform
from ..data import get_object
__all__ = [
'Landmark',
@@ -59,10 +61,13 @@ class Landmark:
self.output_names = output_names
assert len(self.output_names)==1
output_shape = outputs[0].shape
self.require_pose = False
#print('init output_shape:', output_shape)
if output_shape[1]==3309:
self.lmk_dim = 3
self.lmk_num = 68
self.mean_lmk = get_object('meanshape_68.pkl')
self.require_pose = True
else:
self.lmk_dim = 2
self.lmk_num = output_shape[1]//self.lmk_dim
@@ -98,6 +103,12 @@ class Landmark:
IM = cv2.invertAffineTransform(M)
pred = face_align.trans_points(pred, IM)
face[self.taskname] = pred
if self.require_pose:
P = transform.estimate_affine_matrix_3d23d(self.mean_lmk, pred)
s, R, t = transform.P2sRt(P)
rx, ry, rz = transform.matrix2angle(R)
pose = np.array( [rx, ry, rz], dtype=np.float32 )
face['pose'] = pose #pitch, yaw, roll
return pred

View File

@@ -18,31 +18,18 @@ from ..utils import download_onnx
__all__ = ['get_model']
def init_session(model_path, **kwargs):
sess = onnxruntime.InferenceSession(model_path,**kwargs)
return sess
class PickableInferenceSession:
class PickableInferenceSession(onnxruntime.InferenceSession):
# This is a wrapper to make the current InferenceSession class pickable.
def __init__(self, model_path, **kwargs):
super().__init__(model_path, **kwargs)
self.model_path = model_path
self.sess = init_session(self.model_path, **kwargs)
def run(self, *args):
return self.sess.run(*args)
def __getstate__(self):
return {'model_path': self.model_path}
def __setstate__(self, values):
self.model_path = values['model_path']
self.sess = init_session(self.model_path)
def get_inputs(self):
return self.sess.get_inputs()
def get_outputs(self):
return self.sess.get_outputs()
model_path = values['model_path']
self.__init__(model_path)
class ModelRouter:
def __init__(self, onnx_file):

View File

@@ -4,7 +4,8 @@ import os.path as osp
import zipfile
from .download import download_file
BASE_REPO_URL='http://storage.insightface.ai/files'
#BASE_REPO_URL='http://storage.insightface.ai/files'
BASE_REPO_URL='http://insightface.cn-sh2.ufileos.com'
def download(sub_dir, name, force=False, root='~/.insightface'):
_root = os.path.expanduser(root)
@@ -13,7 +14,8 @@ def download(sub_dir, name, force=False, root='~/.insightface'):
return dir_path
print('download_path:', dir_path)
zip_file_path = os.path.join(_root, sub_dir, name + '.zip')
model_url = "%s/%s/%s.zip"%(BASE_REPO_URL, sub_dir, name)
#model_url = "%s/%s/%s.zip"%(BASE_REPO_URL, sub_dir, name)
model_url = "%s/%s.zip"%(BASE_REPO_URL, name)
download_file(model_url,
path=zip_file_path,
overwrite=True)

View File

@@ -0,0 +1,116 @@
import cv2
import math
import numpy as np
from skimage import transform as trans
def transform(data, center, output_size, scale, rotation):
scale_ratio = scale
rot = float(rotation) * np.pi / 180.0
#translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
t1 = trans.SimilarityTransform(scale=scale_ratio)
cx = center[0] * scale_ratio
cy = center[1] * scale_ratio
t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
t3 = trans.SimilarityTransform(rotation=rot)
t4 = trans.SimilarityTransform(translation=(output_size / 2,
output_size / 2))
t = t1 + t2 + t3 + t4
M = t.params[0:2]
cropped = cv2.warpAffine(data,
M, (output_size, output_size),
borderValue=0.0)
return cropped, M
def trans_points2d(pts, M):
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
for i in range(pts.shape[0]):
pt = pts[i]
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
new_pt = np.dot(M, new_pt)
#print('new_pt', new_pt.shape, new_pt)
new_pts[i] = new_pt[0:2]
return new_pts
def trans_points3d(pts, M):
scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
#print(scale)
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
for i in range(pts.shape[0]):
pt = pts[i]
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
new_pt = np.dot(M, new_pt)
#print('new_pt', new_pt.shape, new_pt)
new_pts[i][0:2] = new_pt[0:2]
new_pts[i][2] = pts[i][2] * scale
return new_pts
def trans_points(pts, M):
if pts.shape[1] == 2:
return trans_points2d(pts, M)
else:
return trans_points3d(pts, M)
def estimate_affine_matrix_3d23d(X, Y):
''' Using least-squares solution
Args:
X: [n, 3]. 3d points(fixed)
Y: [n, 3]. corresponding 3d points(moving). Y = PX
Returns:
P_Affine: (3, 4). Affine camera matrix (the third row is [0, 0, 0, 1]).
'''
X_homo = np.hstack((X, np.ones([X.shape[0],1]))) #n x 4
P = np.linalg.lstsq(X_homo, Y)[0].T # Affine matrix. 3 x 4
return P
def P2sRt(P):
''' decompositing camera matrix P
Args:
P: (3, 4). Affine Camera Matrix.
Returns:
s: scale factor.
R: (3, 3). rotation matrix.
t: (3,). translation.
'''
t = P[:, 3]
R1 = P[0:1, :3]
R2 = P[1:2, :3]
s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2.0
r1 = R1/np.linalg.norm(R1)
r2 = R2/np.linalg.norm(R2)
r3 = np.cross(r1, r2)
R = np.concatenate((r1, r2, r3), 0)
return s, R, t
def matrix2angle(R):
''' get three Euler angles from Rotation Matrix
Args:
R: (3,3). rotation matrix
Returns:
x: pitch
y: yaw
z: roll
'''
sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0])
singular = sy < 1e-6
if not singular :
x = math.atan2(R[2,1] , R[2,2])
y = math.atan2(-R[2,0], sy)
z = math.atan2(R[1,0], R[0,0])
else :
x = math.atan2(-R[1,2], R[1,1])
y = math.atan2(-R[2,0], sy)
z = 0
# rx, ry, rz = np.rad2deg(x), np.rad2deg(y), np.rad2deg(z)
rx, ry, rz = x*180/np.pi, y*180/np.pi, z*180/np.pi
return rx, ry, rz

View File

@@ -27,8 +27,10 @@ def find_version(*file_paths):
pypandoc_enabled = True
try:
import pypandoc
print('pandoc enabled')
long_description = pypandoc.convert('README.md', 'rst')
except (IOError, ImportError, ModuleNotFoundError):
print('WARNING: pandoc not enabled')
long_description = open('README.md').read()
pypandoc_enabled = False
@@ -64,7 +66,10 @@ data_mesh = list(glob.glob('insightface/thirdparty/face3d/mesh/cython/*.h'))
data_mesh += list(glob.glob('insightface/thirdparty/face3d/mesh/cython/*.c'))
data_mesh += list(glob.glob('insightface/thirdparty/face3d/mesh/cython/*.py*'))
data_objects = list(glob.glob('insightface/data/objects/*.pkl'))
data_files = [ ('insightface/data/images', data_images) ]
data_files += [ ('insightface/data/objects', data_objects) ]
data_files += [ ('insightface/thirdparty/face3d/mesh/cython', data_mesh) ]
ext_modules=cythonize(extensions)