# File: insightface/reconstruction/ostec/core/operator.py (436 lines, 20 KiB, Python)

# Copyright (c) 2020, Baris Gecer. All rights reserved.
#
# This work is made available under the CC BY-NC-SA 4.0.
# To view a copy of this license, see LICENSE
from utils.align2stylegan import align_im2stylegan, align_mesh2stylegan
from core.projection_handler import Projection_Handler
from skimage.morphology import remove_small_holes
import time
from utils.ganfit_camera import apply_camera_only3d, get_pose
from utils.utils import *
from core.arcface_handler import Arcface_Handler
from utils import generate_heatmap
class Face:
    """Mutable state of one face while its UV texture is being completed.

    Instances are built by ``Operator.get_tmesh`` and then updated in place by
    ``Operator.run`` and ``Operator.run_iteration``.

    The effective ``mode`` ('soft' or 'hard') selects which synthetic views
    are rendered (``rotation_dict``) and how each view is weighted
    (``coef_dict``).
    """

    def __init__(self, tmesh,
                 tmesh_masked,
                 tmesh_rotated,
                 img_uv_src,
                 angle_uv_src,
                 angle_uv_list,
                 img_uv_list,
                 view_angle_src,
                 id_features,
                 exclude_mask,
                 is_profile,
                 mode):
        """Store the per-face data and resolve the working mode.

        All arguments are stored unchanged on attributes of the same name,
        except ``mode``: the value 'auto' is resolved to 'hard' when
        ``is_profile`` is truthy and to 'soft' otherwise. Any other value is
        stored as given and only validated when ``rotation_dict`` or
        ``coef_dict`` is called.
        """
        self.tmesh = tmesh
        self.tmesh_masked = tmesh_masked
        self.tmesh_rotated = tmesh_rotated
        self.img_uv_src = img_uv_src
        self.angle_uv_src = angle_uv_src
        self.angle_uv_list = angle_uv_list
        self.img_uv_list = img_uv_list
        self.view_angle_src = view_angle_src
        self.id_features = id_features
        self.is_profile = is_profile
        self.exclude_mask = exclude_mask
        # Assigned by Operator.run and read by Operator.run_iteration; declared
        # here so the attribute always exists on a freshly built Face.
        self.uv_blending = {}
        if mode == 'auto':
            self.mode = 'hard' if is_profile else 'soft'
        else:
            self.mode = mode

    def _require_known_mode(self):
        """Raise ``ValueError`` unless ``self.mode`` is 'soft' or 'hard'."""
        if self.mode not in ('soft', 'hard'):
            raise ValueError('Unknown mode!')

    def rotation_dict(self):
        """Return the target views as ``{view_name: [x, y, z]}``.

        The three numbers are handed on as ``pose_angle_deg`` by the callers
        in ``Operator``. 'hard' mode renders one extra 'front' view, placed
        first; all other views are identical in both modes. Insertion order
        matters because callers iterate the dict to process the views in
        sequence. A new dict with new lists is returned on every call.

        Raises:
            ValueError: if ``self.mode`` is neither 'soft' nor 'hard'.
        """
        self._require_known_mode()
        views = {}
        if self.mode == 'hard':
            views['front'] = [0, 0, 0]
        views['bottom'] = [-30, 0, 0]
        views['bottom_left'] = [-15, -30, 0]
        views['bottom_right'] = [-15, 30, 0]
        views['left'] = [5, -40, 0]
        views['right'] = [5, 40, 0]
        return views

    def coef_dict(self):
        """Return the per-view weights, including 'src' for the input image.

        Callers multiply the visibility scores of a view by its weight. In
        'soft' mode the source image gets weight 2; in 'hard' mode it gets 0.1
        and the extra 'front' view gets 2.

        Raises:
            ValueError: if ``self.mode`` is neither 'soft' nor 'hard'.
        """
        self._require_known_mode()
        is_hard = self.mode == 'hard'
        coefs = {}
        if is_hard:
            coefs['front'] = 2
        coefs['bottom'] = 1.3
        coefs['bottom_left'] = 1
        coefs['bottom_right'] = 1
        coefs['left'] = 1
        coefs['right'] = 1
        coefs['src'] = 0.1 if is_hard else 2
        return coefs
class Operator:
    """Drive the UV-texture completion of a reconstructed face.

    ``run`` is the entry point: it builds a ``Face`` from the input image and
    reconstruction, scores per-vertex visibility for the source image and for
    every synthetic view, then projects each view through
    ``Projection_Handler`` and stitches the results into one UV texture.
    """

    def __init__(self, args):
        """Load the topology pickles and create the ArcFace/projection handlers.

        Args:
            args: options object; this class reads ``mode``, ``model_res``,
                ``frontalize``, ``use_id_loss_frontalize`` and
                ``iterations_frontalize`` from it and passes the whole object
                to ``Projection_Handler``.
        """
        self.tcoords_full = mio.import_pickle('models/topology/tcoords_full.pkl')
        self.tcoords = mio.import_pickle('models/topology/tcoords_alex.pkl')
        # OR-ing with True makes every entry truthy, so this "crop" mask
        # currently selects everything (assuming the pickle holds a bool array).
        self.mask = mio.import_pickle('models/topology/mask_full2crop.pkl') | True
        self.tight_mask = mio.import_pickle('models/topology/mask_full2tightcrop.pkl')
        self.template = mio.import_pickle('models/topology/all_all_all_crop_mean.pkl')
        self.lms_ind = mio.import_pickle('models/topology/all_all_all_lands_ids.pkl')
        self.img_shape = [1024, 1024]  # 2048
        self.uv_shape = [1024, 1536]

        # Flat mesh laid out in UV pixel space: swap the two tcoord columns,
        # flip the first axis, scale to uv_shape and append a zero third column.
        uv_mesh = self.tcoords.copy().points[:, ::-1]
        uv_mesh[:, 0] = 1 - uv_mesh[:, 0]
        uv_mesh *= self.uv_shape
        self.uv_mesh = np.concatenate([uv_mesh, uv_mesh[:, 0:1] * 0], 1)
        self.uv_trilist = mio.import_pickle('models/topology/trilist_full.pkl')  # self.template.trilist

        self.args = args
        self.mode = args.mode  # 'soft', 'hard', 'auto'
        self.arcface_handler = Arcface_Handler()
        self.projector = Projection_Handler(args)

    @staticmethod
    def _view_angle(mesh):
        """Return, per vertex, dot(-p / |p|, normal) for vertex position p.

        With the camera at the origin this is the cosine between the vertex
        normal and the direction towards the camera, provided
        ``mesh.vertex_normals()`` returns unit vectors (not verified here).
        """
        distance = np.linalg.norm(mesh.points, axis=1, keepdims=True)
        camera_direction = -mesh.points / distance
        return np.sum(camera_direction * mesh.vertex_normals(), 1)

    def render_uv_image(self, generated, tcoords):
        """Rasterize image ``generated`` into UV space.

        ``tcoords`` gives, per vertex of the flat UV mesh, where to sample
        ``generated``. The result is clipped to [0, 1].
        """
        uv_tmesh = TexturedTriMesh(self.uv_mesh, tcoords, generated, trilist=self.uv_trilist)
        bcs = rasterize_barycentric_coordinate_images(uv_tmesh, self.uv_shape)
        img = rasterize_mesh_from_barycentric_coordinate_images(uv_tmesh, *bcs)
        img.pixels = np.clip(img.pixels, 0.0, 1.0)
        return img

    def render_colored_image(self, view_angle_trg, return_visibility=False):
        """Rasterize per-vertex scalars into a 3-channel UV image.

        Args:
            view_angle_trg: one value per vertex of the flat UV mesh; it is
                repeated over three colour channels.
            return_visibility: when true, also return a boolean map that is
                True where the first barycentric image has a non-zero sum
                over its channels.

        Returns:
            The image clipped to [-1, 1], or ``(image, visible)``.
        """
        colours = np.tile(view_angle_trg, [3, 1]).T
        uv_cmesh = ColouredTriMesh(self.uv_mesh, trilist=self.uv_trilist, colours=colours)
        bcs = rasterize_barycentric_coordinate_images(uv_cmesh, self.uv_shape)
        img = rasterize_mesh_from_barycentric_coordinate_images(uv_cmesh, *bcs)
        img.pixels = np.clip(img.pixels, -1.0, 1.0)
        if return_visibility:
            visible = np.sum(bcs[0].pixels, axis=0) != 0
            return img, visible
        return img

    def camera_tri_angle_src(self, tmesh):
        """Return per-vertex view scores of ``tmesh`` as it is already posed."""
        return self._view_angle(tmesh)

    def camera_tri_angle(self, tmesh, pose_angle_deg=None, cam_dist=-4.5):
        """Return per-vertex view scores after posing ``tmesh`` for a camera.

        Args:
            tmesh: mesh to transform; the transform is applied to a result of
                ``camera.apply`` and the scores are computed on that result.
            pose_angle_deg: ``[x, y, z]`` rotation angles; defaults to
                ``[0, 0, 0]``. The y angle is negated before use.
            cam_dist: z translation applied after the rotation.
        """
        if pose_angle_deg is None:
            pose_angle_deg = [0, 0, 0]
        rot_z = rotation_z(pose_angle_deg[2])
        rot_y = rotation_y(-pose_angle_deg[1])
        rot_x = rotation_x(pose_angle_deg[0])
        rotation = rot_z.compose_before(rot_y).compose_before(rot_x)
        translation = Translation([0, 0, +cam_dist])
        camera = rotation.compose_before(translation)
        return self._view_angle(camera.apply(tmesh))

    def create_syn(self, face, trg_angle=None, include_mask=None):
        """Render ``face`` from ``trg_angle`` together with a fill mask.

        A vertex is marked in the fill mask when it scores lower in the
        target view than in the source view, or when its source score exceeds
        0.4 — in both cases restricted to ``self.tight_mask``. Vertices in
        ``include_mask`` are added, and vertices whose UV position falls on
        ``face.exclude_mask`` are removed.

        Returns:
            ``(image, projected_mesh[:, :2], mask_image)``.
        """
        if trg_angle is None:
            trg_angle = [0, 0, 0]
        view_angle_trg = self.camera_tri_angle(face.tmesh, pose_angle_deg=trg_angle)
        im, projected_mesh = rasterize_image(face.tmesh, self.img_shape, pose_angle_deg=trg_angle, cam_dist=4.5)

        fill_mask = ((view_angle_trg < face.view_angle_src) | (face.view_angle_src > 0.4)) & self.tight_mask
        if include_mask is not None:
            fill_mask = fill_mask | include_mask.astype(bool)
        if face.exclude_mask is not None:
            # Look up exclude_mask at each vertex's UV position (row axis flipped).
            # NOTE(review): a rounded coordinate equal to the mask size would
            # index one past the end — confirm tcoords never reach that edge.
            tcoord_sampling = np.round(self.tcoords.points[:, ::-1] * face.exclude_mask.shape).astype(int)
            excluded = face.exclude_mask[face.exclude_mask.shape[0] - tcoord_sampling[:, 0], tcoord_sampling[:, 1]]
            fill_mask[self.mask] = fill_mask[self.mask] & ~excluded

        mask_mesh = ColouredTriMesh(face.tmesh.points, trilist=face.tmesh.trilist, colours=np.tile(fill_mask, [3, 1]).T)
        mask = rasterize_image(mask_mesh, self.img_shape, pose_angle_deg=trg_angle, cam_dist=4.5)[0]
        return im, projected_mesh[:, :2], mask

    def create_align_syn(self, face, trg_angle=None, include_mask=None):
        """Render a view and align it, its mask and its mesh via the StyleGAN helpers.

        Returns:
            ``(imgs, masks, heatmaps, aligned_meshes)`` where ``heatmaps`` is
            generated from the aligned landmark vertices and
            ``aligned_meshes`` is restricted to ``self.mask``.
        """
        im, projected_mesh, mask = self.create_syn(face, trg_angle, include_mask)
        imgs, masks, transformation_params = align_im2stylegan(
            im_menpo2PIL(im),
            im_menpo2PIL(self.extend_mask(im, mask)),
            projected_mesh[self.lms_ind][:, ::-1])
        aligned_meshes = align_mesh2stylegan(projected_mesh, transformation_params)

        landmarks = aligned_meshes[self.lms_ind]
        landmarks[:, 1] = 1 - landmarks[:, 1]
        heatmaps = generate_heatmap.generate_heatmaps(width=self.args.model_res,
                                                      height=self.args.model_res,
                                                      points=landmarks * self.args.model_res,
                                                      sigma=25)
        aligned_meshes = aligned_meshes[self.mask]
        return imgs, masks, heatmaps, aligned_meshes

    def get_tmesh(self, im, reconstruction_dict, face_mask):
        """Build the initial ``Face`` from the input image and reconstruction.

        Steps, in order: extract identity features; unwrap the (optionally
        masked) image into UV space; score source visibility per vertex; blank
        low-visibility texels; fill part of the blanked area from the
        horizontally mirrored texture; blank what is still below the
        visibility threshold.

        Args:
            im: input image (menpo image, judging by ``im_menpo2PIL``).
            reconstruction_dict: must provide 'dense_lms', 'euler_angles',
                'vertices', 'vertices_rotated' and 'trilist'.
            face_mask: optional mask over ``im``; falsy pixels are dropped.

        Returns:
            A ``Face`` whose ``angle_uv_list`` / ``img_uv_list`` hold the
            source entry only.
        """
        id_features = self.arcface_handler.get_identity_features(
            im, reconstruction_dict['dense_lms'][self.lms_ind])

        _, yaw_angle, _ = reconstruction_dict['euler_angles']
        # Faces turned more than 30 degrees (yaw converted from radians) are
        # treated as profile views.
        is_profile = np.abs(yaw_angle * 180 / np.pi) > 30
        visibility_threshold = 0.4

        dense_lms = reconstruction_dict['dense_lms'] / im.shape[::-1]
        dense_lms[:, 1] = 1 - dense_lms[:, 1]

        im_masked = np.array(im_menpo2PIL(im))
        mask_landmarks = np.ones_like(im_masked[:, :, 0])
        if face_mask is not None:
            im_masked = im_masked * np.repeat(np.expand_dims(np.array(face_mask, bool), 2), 3, 2)
            mask_landmarks *= np.array(face_mask, np.uint8)
        im_masked = fill_UV(im_PIL2menpo(im_masked))
        # Carry the mask as a 4th channel so it is unwrapped with the colours.
        im_masked.pixels = np.concatenate([im_masked.pixels, np.expand_dims(mask_landmarks, 0)], 0)

        img_uv_src = self.render_uv_image(im_masked, dense_lms[self.mask])
        mask_landmarks = img_uv_src.pixels[3] < 0.5
        img_uv_src.pixels = img_uv_src.pixels[0:3]
        img_uv_src = fill_UV(img_uv_src)
        if is_profile:
            mask_landmarks = binary_dilation(mask_landmarks, disk(5))
            visibility_threshold = 0.6
        # NOTE(review): taken to be unconditional (the source lost its
        # indentation); mask_landmarks is also used unconditionally below.
        img_uv_src.pixels[:, mask_landmarks] = 0

        tcoords = self.tcoords_full.copy()
        tcoords.points[self.mask] = self.tcoords.points
        tmesh = TexturedTriMesh(reconstruction_dict['vertices'], tcoords.points, img_uv_src,
                                trilist=reconstruction_dict['trilist'])
        tmesh_masked = tmesh.from_mask(self.mask)
        tmesh_rotated = TexturedTriMesh(reconstruction_dict['vertices_rotated'], tmesh.tcoords.points,
                                        tmesh.texture, trilist=tmesh.trilist)

        view_angle_src = self.camera_tri_angle_src(tmesh_rotated)
        view_angle_src_masked = view_angle_src[self.mask]
        view_angle_src_masked[~self.tight_mask[self.mask]] = -1  # Only take tight crop from the original image
        angle_uv_src, visible = self.render_colored_image(view_angle_src_masked, return_visibility=True)
        angle_uv_src.pixels[:, ~visible | mask_landmarks] = -1.0

        mask = angle_uv_src.pixels[0] < visibility_threshold
        mask = ~remove_small_holes(~mask, area_threshold=100000)
        if is_profile and self.mode == 'soft':
            mask = binary_dilation(mask, disk(10))
        # NOTE(review): taken to be unconditional, matching the identical
        # unconditional blanking step near the end of this method.
        img_uv_src.pixels[:, mask] = 0
        angle_uv_src.pixels[:, mask] = -1

        # Mirror texture and visibility, then shift the mirrored copies right
        # by `pad` columns (zeros enter on the left, the right edge is dropped).
        img_uv_src_flipped = img_uv_src.mirror(1)
        angle_uv_src_flipped = angle_uv_src.mirror(1)
        pad = int((16 / 1024) * self.uv_shape[1])
        temp = img_uv_src_flipped.pixels
        img_uv_src_flipped.pixels = np.concatenate(
            [np.zeros([temp.shape[0], temp.shape[1], pad]), temp[:, :, :-pad]], 2)
        temp = angle_uv_src_flipped.pixels
        angle_uv_src_flipped.pixels = np.concatenate(
            [np.zeros([temp.shape[0], temp.shape[1], pad]), temp[:, :, :-pad]], 2)
        img_uv_src_flipped = fill_UV(img_uv_src_flipped)
        img_uv_src = fill_UV(img_uv_src)

        # Use the mirrored side where it is visible and the original is blanked.
        mask_flipped = (angle_uv_src_flipped.pixels[0] > visibility_threshold) & mask
        mask_flipped = remove_small_holes(mask_flipped, area_threshold=100000)
        angle_uv_src.pixels = mask_flipped * angle_uv_src_flipped.pixels + (1 - mask_flipped) * angle_uv_src.pixels

        # Label map: 1 = take mirrored texel, 2 = keep original texel, 0 = neither.
        mask_all = mask_flipped.astype(int).copy()
        mask_all[~mask_flipped & ~mask] = 2
        mask_all = fill_UV(Image(np.tile(mask_all, [3, 1, 1]))).pixels[0]
        # NOTE(review): `multichannel` is deprecated in newer scikit-image in
        # favour of `channel_axis`; left as is so older releases keep working.
        mask_flipped_g = gaussian(mask_all == 1, sigma=30, multichannel=True, mode='reflect')
        mask_flipped_inv_g = gaussian(mask_all == 2, sigma=30, multichannel=True, mode='reflect')
        img_uv_src.pixels = mask_flipped_g * img_uv_src_flipped.pixels + mask_flipped_inv_g * img_uv_src.pixels

        mask = (angle_uv_src.pixels[0] < visibility_threshold)
        mask = ~remove_small_holes(~mask, area_threshold=100000)
        img_uv_src.pixels[:, mask] = 0
        angle_uv_src.pixels[:, mask] = -1
        img_uv_src = fill_UV(img_uv_src)

        tmesh.texture = img_uv_src
        tmesh_rotated.texture = img_uv_src
        face = Face(
            tmesh=tmesh,
            tmesh_masked=tmesh_masked,
            tmesh_rotated=tmesh_rotated,
            img_uv_src=img_uv_src,
            angle_uv_src=angle_uv_src,
            angle_uv_list=[],
            img_uv_list=[],
            view_angle_src=view_angle_src,
            id_features=id_features,
            exclude_mask=mask,
            is_profile=is_profile,
            mode=self.mode
        )
        # Filled after construction because the 'src' weight depends on the
        # mode that Face resolves for itself.
        face.angle_uv_list = [np.clip(angle_uv_src.pixels * face.coef_dict()['src'], -1, 1)]
        face.img_uv_list = [fill_UV(img_uv_src).pixels]
        return face

    def extend_mask(self, im, mask):
        """Return ``mask`` unchanged.

        Kept as a hook: the mask-extension logic that used to live here is
        disabled, and ``im`` is currently unused.
        """
        return mask

    def run_iteration(self, face, key, trg_angle):
        """Project one synthetic view and merge it into the face's UV texture.

        Args:
            face: the ``Face`` being completed; updated in place.
            key: view name, used to look up ``face.uv_blending`` and to key
                the projector's inputs and outputs.
            trg_angle: pose of the view, passed to ``create_align_syn``.

        Returns:
            ``(face, results_dict)`` with the generated image, latents, the
            inputs that were projected, this view's UV image and the stitched
            UV texture.
        """
        imgs, masks, heatmaps, aligned_meshes = self.create_align_syn(face, trg_angle, face.uv_blending[key])

        # Run Optimizer
        generated_imgs, generated_dlatents = self.projector.run_projection({key: imgs},
                                                                           {key: masks},
                                                                           {key: heatmaps},
                                                                           face.id_features)

        img_uv = self.render_uv_image(im_PIL2menpo(generated_imgs[key]), aligned_meshes)
        img_uv = fill_UV(img_uv)
        # angle_uv_list already holds every view, img_uv_list only the ones
        # done so far, so its length indexes the angle map of this view.
        img_uv = uv_color_normalize(face.img_uv_src, face.angle_uv_src, img_uv,
                                    Image(face.angle_uv_list[len(face.img_uv_list)]))
        face.img_uv_list.append(img_uv.pixels)
        final_uv, _ = uv_stiching(face.img_uv_list, face.angle_uv_list[:len(face.img_uv_list)], 40)

        results_dict = {
            'generated_imgs': generated_imgs[key],
            'generated_dlatents': generated_dlatents[key],
            'imgs': imgs,
            'masks': masks,
            'aligned_meshes': aligned_meshes,
            'img_uv': img_uv,
            'final_uv': final_uv
        }

        # Later views are rendered from the texture stitched so far.
        face.img_uv_src = final_uv
        face.tmesh.texture = final_uv
        face.tmesh.tcoords = self.tcoords_full.copy()
        face.tmesh.tcoords.points[self.mask] = self.tcoords.points
        return face, results_dict

    def run(self, im, reconstruction_dict, face_mask=None):
        """Complete the UV texture of the face shown in ``im``.

        Args:
            im: input image.
            reconstruction_dict: 3D reconstruction of the face. When
                'vertices', 'vertices_rotated' or 'euler_angles' are missing
                they are derived from 'tmesh' / 'camera_params' (GANFit
                output); the dict is modified in place in that case.
            face_mask: optional mask over ``im``, forwarded to ``get_tmesh``.

        Returns:
            ``(final_uv, results_dict)``: the UV texture stitched after the
            last view, and the per-view results keyed by view name, plus
            'frontal' and 'frontalize' when ``args.frontalize`` is set.
        """
        start = time.time()
        print('Preprocessing...', end=" ")
        # GANFit compatibility
        # NOTE(review): the three checks are taken to be independent of each
        # other (the source lost its indentation).
        if 'vertices' not in reconstruction_dict:  # GANFit
            reconstruction_dict['vertices'] = reconstruction_dict['tmesh'].points
            reconstruction_dict['trilist'] = reconstruction_dict['tmesh'].trilist
        if 'vertices_rotated' not in reconstruction_dict:  # GANFit
            reconstruction_dict['vertices_rotated'] = apply_camera_only3d(
                reconstruction_dict['vertices'], reconstruction_dict['camera_params'])
        if 'euler_angles' not in reconstruction_dict:  # GANFit
            reconstruction_dict['euler_angles'] = get_pose(reconstruction_dict['camera_params'])

        # Prepare Textured Trimesh with visible part of the face
        face = self.get_tmesh(im, reconstruction_dict, face_mask)
        img_uv_src = face.img_uv_src.copy()
        angle_uv_src = face.angle_uv_src.copy()
        print('Done in %.2f secs' % (time.time() - start))

        # Prepare view angle maps
        start = time.time()
        print('Building a Visibility Index...', end=" ")
        key_list = ['src']
        angle_uv_list = [np.clip(angle_uv_src.pixels * face.coef_dict()['src'], -1, 1)]

        # Per-vertex source scores, overwritten inside the crop with the
        # values sampled back from the processed UV angle map.
        view_angle_src_full = self.camera_tri_angle_src(face.tmesh_rotated)
        tcoord_sampling = np.round(self.tcoords.points * angle_uv_src.shape).astype(int)
        view_angle_src_full[self.mask] = angle_uv_src.pixels[
            0, angle_uv_src.shape[0] - tcoord_sampling[:, 1], tcoord_sampling[:, 0]]
        view_angle_src_full[~self.tight_mask] = -1  # Only take tight crop from the original image
        angle_list = [np.clip(view_angle_src_full * face.coef_dict()['src'], -1, 1)]

        dummy_im = im_menpo2PIL(img_uv_src)
        # For each view calculate angles towards the camera (Visibility scores)
        for key, trg_angle in face.rotation_dict().items():
            view_angle_trg = self.camera_tri_angle(face.tmesh, pose_angle_deg=trg_angle)
            view_angle_trg = np.clip(view_angle_trg * face.coef_dict()[key], -1, 1)
            _, projected_mesh = rasterize_image(face.tmesh, self.img_shape, pose_angle_deg=trg_angle,
                                                cam_dist=4.5)
            # Only the alignment parameters are needed here; the image passed
            # in is a placeholder.
            _, _, transformation_params = align_im2stylegan(dummy_im, dummy_im,
                                                            projected_mesh[self.lms_ind, :2][:, ::-1])
            aligned_meshes = align_mesh2stylegan(projected_mesh[:, :2], transformation_params)
            # Vertices that land outside the aligned unit square cannot be
            # seen in this view.
            out_of_plane = ((aligned_meshes[:, 0] > 1) |
                            (aligned_meshes[:, 1] > 1) |
                            (aligned_meshes[:, 0] < 0) |
                            (aligned_meshes[:, 1] < 0))
            view_angle_trg[out_of_plane] = -1
            angle_list.append(view_angle_trg)
            angle_uv_list.append(self.render_colored_image(view_angle_trg[self.mask]).pixels)
            key_list.append(key)

        # Building a Visibility Index
        # One-hot of the best-scoring view per vertex. The width is the number
        # of views, so a view that never wins still has a (zero) column.
        max_ind = np.argmax(angle_list, axis=0)
        max_ind_one_hot = np.zeros((max_ind.size, len(angle_list)))
        max_ind_one_hot[np.arange(max_ind.size), max_ind.flatten()] = 1
        max_ind_one_hot = max_ind_one_hot.reshape(max_ind.shape + (-1,))
        mask_out_all = np.max(angle_list, axis=0) == -1
        max_ind_one_hot[mask_out_all, :] = 0

        # For each view: 1 where an earlier entry of key_list is the best view.
        uv_blending = {}
        for i, key in enumerate(key_list):
            uv_blending[key] = np.clip(max_ind_one_hot[:, :i].sum(axis=1), 0, 1)

        face.uv_blending = uv_blending
        face.angle_uv_list = angle_uv_list
        print('Done in %.2f secs' % (time.time() - start))

        # Projecting for each of the predefined views
        start = time.time()
        print('Projecting...')
        results_dict = {}
        for key, trg_angle in face.rotation_dict().items():
            face, results_dict[key] = self.run_iteration(face, key, trg_angle)
        # key_list ends with the last view of rotation_dict().
        last_key = key_list[-1]
        final_uv = results_dict[last_key]['final_uv']
        print('Done in %.2f secs' % (time.time() - start))

        if self.args.frontalize:
            start = time.time()
            print('Frontalizing...')
            imgs = {}
            masks = {}
            heatmaps = {}
            self.projector.perceptual_model.assign_placeholder('id_loss', self.args.use_id_loss_frontalize)
            imgs['frontal'], masks['frontal'], heatmaps['frontal'], _ = self.create_align_syn(
                face, trg_angle=[0, 0, 0], include_mask=face.uv_blending[last_key])
            generated_imgs, generated_dlatents = self.projector.run_projection(
                imgs, masks, heatmaps, face.id_features, iterations=self.args.iterations_frontalize)
            results_dict['frontal'] = im_PIL2menpo(generated_imgs['frontal'])
            results_dict['frontalize'] = [imgs, masks, heatmaps, face.id_features]
            print('Done in %.2f secs' % (time.time() - start))

        return final_uv, results_dict