Improved landmark differentiability by heatmaps.

This commit is contained in:
Baris Gecer
2022-05-29 14:27:58 +01:00
parent 2e5d23ee0e
commit 2a8b181d4d
6 changed files with 166 additions and 24 deletions

View File

@@ -63,7 +63,7 @@ parser.add_argument('--use_lpips_loss', default=100, help='Use LPIPS perceptual
parser.add_argument('--use_l1_penalty', default=0.5, help='Use L1 penalty on latents; 0 to disable, > 0 to scale.', type=float)
parser.add_argument('--use_discriminator_loss', default=0.5, help='Use trained discriminator to evaluate realism.', type=float)
parser.add_argument('--use_adaptive_loss', default=False, help='Use the adaptive robust loss function from Google Research for pixel and VGG feature loss.', type=str2bool, nargs='?', const=True)
parser.add_argument('--use_landmark_loss', default=5, help='Use landmark loss; 0 to disable, > 0 to scale.', type=float)
parser.add_argument('--use_landmark_loss', default=200, help='Use landmark loss; 0 to disable, > 0 to scale.', type=float)
parser.add_argument('--use_id_loss', default=10, help='Use landmark loss; 0 to disable, > 0 to scale.', type=float)
parser.add_argument('--use_id_loss_frontalize', default=100, help='Use landmark loss; 0 to disable, > 0 to scale.', type=float)

View File

@@ -24,9 +24,9 @@ class Landmark_Handler():
net_model = networks.DNFaceMultiView('')
with tf.variable_scope('net'):
lms_heatmap_prediction, states = net_model._build_network(generated_image, datas=None, is_training=False,
self.lms_heatmap_prediction, states = net_model._build_network(generated_image, datas=None, is_training=False,
n_channels=n_landmarks)
self.pts_predictions = tf_heatmap_to_lms(lms_heatmap_prediction)
self.pts_predictions = tf_heatmap_to_lms(self.lms_heatmap_prediction)
variables = tf.all_variables()
variables_to_restore = [v for v in variables if v.name.split('/')[0] == 'net']
self.saver = tf.train.Saver(variables_to_restore)

View File

@@ -10,7 +10,8 @@ import time
from utils.ganfit_camera import apply_camera_only3d, get_pose
from utils.utils import *
from core.arcface_handler import Arcface_Handler
import cv2
from utils import generate_heatmap
class Face:
def __init__(self, tmesh,
@@ -256,11 +257,17 @@ class Operator:
projected_mesh[self.lms_ind][:,::-1])
aligned_meshes = align_mesh2stylegan(projected_mesh, transformation_params)
landmarks = aligned_meshes[self.lms_ind]
landmarks[:,1] = 1 - landmarks[:,1]
heatmaps = generate_heatmap.generate_heatmaps(width=self.args.model_res,
height=self.args.model_res,
points=landmarks*self.args.model_res,
sigma=25)
landmarks = landmarks[:,::-1]
landmarks[:,0] = 1 - landmarks[:,0]
aligned_meshes = aligned_meshes[self.mask]
#TODO: landmarks to heatmaps
return imgs, masks, landmarks, aligned_meshes
return imgs, masks, heatmaps, aligned_meshes
def get_tmesh(self, im, reconstruction_dict, face_mask):
id_features = self.arcface_handler.get_identity_features(im, reconstruction_dict['dense_lms'][self.lms_ind])
@@ -377,12 +384,12 @@ class Operator:
def run_iteration(self, face, key, trg_angle):
imgs, masks, landmarks, aligned_meshes = self.create_align_syn(face, trg_angle, face.uv_blending[key])
imgs, masks, heatmaps, aligned_meshes = self.create_align_syn(face, trg_angle, face.uv_blending[key])
# Run Optimizer
generated_imgs, generated_dlatents = self.projector.run_projection({key: imgs},
{key: masks},
{key: landmarks},
{key: heatmaps},
face.id_features)
img_uv = self.render_uv_image(im_PIL2menpo(generated_imgs[key]), aligned_meshes)
@@ -497,13 +504,13 @@ class Operator:
print('Frontalizing...')
imgs = {}
masks = {}
landmarks = {}
heatmaps = {}
self.projector.perceptual_model.assign_placeholder('id_loss', self.args.use_id_loss_frontalize)
imgs['frontal'], masks['frontal'], landmarks['frontal'], _ = self.create_align_syn(face, trg_angle=[0, 0, 0], include_mask=face.uv_blending[key])
generated_imgs, generated_dlatents = self.projector.run_projection(imgs, masks, landmarks, face.id_features, iterations= self.args.iterations_frontalize)
imgs['frontal'], masks['frontal'], heatmaps['frontal'], _ = self.create_align_syn(face, trg_angle=[0, 0, 0], include_mask=face.uv_blending[key])
generated_imgs, generated_dlatents = self.projector.run_projection(imgs, masks, heatmaps, face.id_features, iterations= self.args.iterations_frontalize)
results_dict['frontal'] = im_PIL2menpo(generated_imgs['frontal'])
results_dict['frontalize'] = [imgs, masks, landmarks, face.id_features]
results_dict['frontalize'] = [imgs, masks, heatmaps, face.id_features]
print('Done in %.2f secs' % (time.time() - start))
return final_uv, results_dict

View File

@@ -98,7 +98,7 @@ class PerceptualModel:
self.perc_model = None
self.ref_img = None
self.ref_weight = None
self.ref_landmarks = None
self.ref_heatmaps = None
self.perceptual_model = None
self.ref_img_features = None
self.features_weight = None
@@ -161,11 +161,12 @@ class PerceptualModel:
landmark_model = Landmark_Handler(self.args, self.sess, generated_image/255)
landmark_model.load_model()
ibug84to68_ind = list(range(0, 33, 2)) + list(range(33, 84))
self.generated_heatmaps = tf.gather(landmark_model.lms_heatmap_prediction, ibug84to68_ind, axis=3)
self.generated_landmarks = tf.gather(landmark_model.pts_predictions, ibug84to68_ind, axis=1)
self.ref_landmarks = tf.get_variable('ref_landmarks', shape=self.generated_landmarks.shape,
dtype='float32', initializer=tf.initializers.zeros())
self.add_placeholder("ref_landmarks")
self.ref_heatmaps = tf.get_variable('ref_heatmaps', shape=self.generated_heatmaps.shape,
dtype='float32', initializer=tf.initializers.zeros())
self.add_placeholder("ref_heatmaps")
self.generated_id, vars, _ = arcface_handler.get_input_features(generated_image / 255, self.generated_landmarks[:, :, ::-1])
self.init_id_vars = tf.variables_initializer(vars)
@@ -209,7 +210,7 @@ class PerceptualModel:
tflib.convert_images_from_uint8(generated_image_tensor, nhwc_to_nchw=True), self.stub))
# - discriminator_network.get_output_for(tflib.convert_images_from_uint8(ref_img, nhwc_to_nchw=True), stub)
if self.landmark_loss is not None:
self.loss += self.landmark_loss * tf.math.reduce_mean(tf.reduce_sum(tf.pow(self.ref_landmarks - self.generated_landmarks,2),2))
self.loss += self.landmark_loss * tf.math.reduce_mean(tf.reduce_sum(tf.pow(self.ref_heatmaps - self.generated_heatmaps, 2), 2))
if self.id_loss is not None:
self.id_loss_comp = tf.losses.cosine_distance(self.generated_id, self.org_features, 1)
self.loss += self.id_loss * self.id_loss_comp
@@ -232,11 +233,16 @@ class PerceptualModel:
self.sess.graph.finalize() # Graph is read-only after this statement.
def set_reference_images(self, images_PIL, masks_PIL, landmarks, id_features):
def set_reference_images(self, images_PIL, masks_PIL, heatmaps, id_features):
assert(len(images_PIL) != 0 and len(images_PIL) <= self.batch_size)
loaded_image = load_images(images_PIL, self.img_size, sharpen=self.sharpen_input)
loaded_mask = load_images(masks_PIL, self.img_size, sharpen=self.sharpen_input, im_type='L')
loaded_landmarks = [lms*self.img_size for lms in landmarks]
heatmaps = np.transpose(np.array(heatmaps), [0, 2, 3, 1])
input_size = np.array(heatmaps).shape[2]
output_size = int(self.ref_heatmaps.shape[1])
bin_size = input_size // output_size
loaded_heatmaps = heatmaps.reshape((heatmaps.shape[0], output_size, bin_size,
output_size, bin_size, 68)).max(4).max(2)
image_features = None
if self.perceptual_model is not None:
image_features = self.perceptual_model.predict_on_batch(preprocess_input(np.array(loaded_image)))
@@ -283,7 +289,7 @@ class PerceptualModel:
self.assign_placeholder("ref_weight", image_mask)
self.assign_placeholder("ref_img", loaded_image)
self.assign_placeholder("org_features", id_features)
self.assign_placeholder("ref_landmarks", loaded_landmarks)
self.assign_placeholder("ref_heatmaps", loaded_heatmaps)
def optimize(self, vars_to_optimize, iterations=200):
self.sess.run(self._reset_global_step)

View File

@@ -76,7 +76,7 @@ class Projection_Handler():
self.perceptual_model.build_perceptual_model(self.generator, discriminator_network)
self.perceptual_model.assign_placeholder('id_loss', args.use_id_loss)
def run_projection(self, input_images, masks, landmarks, id_features, iterations=None):
def run_projection(self, input_images, masks, heatmaps, id_features, iterations=None):
n_iteration = self.args.iterations
if iterations is not None:
n_iteration = iterations
@@ -90,7 +90,7 @@ class Projection_Handler():
# tqdm._instances.clear()
images_batch = [input_images[x] for x in names]
masks_batch = [masks[x] for x in names]
landmarks_batch = [landmarks[x] for x in names]
heatmaps_batch = [heatmaps[x] for x in names]
# if args.output_video:
# video_out = {}
# for name in names:
@@ -118,7 +118,7 @@ class Projection_Handler():
self.generator.set_dlatents(dlatents)
## OPTIMIZATION
self.perceptual_model.set_reference_images(images_batch, masks_batch, landmarks_batch, id_features)
self.perceptual_model.set_reference_images(images_batch, masks_batch, heatmaps_batch, id_features)
op = self.perceptual_model.optimize(self.generator.dlatent_variable, iterations=n_iteration)
pbar = tqdm(op, leave=False, total=n_iteration)

View File

@@ -0,0 +1,129 @@
import numpy as np
import math
import cv2
# Adapted from: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
def _gaussian(size=3, sigma=0.25, amplitude=1, normalize=False, width=None, height=None, sigma_horz=None,
sigma_vert=None, mean_horz=0.5, mean_vert=0.5):
""" Generate a guassian kernel.
Args:
size (int): The size of the kernel if the width or height are not specified
sigma (float): Standard deviation of the kernel if sigma_horz or sigma_vert are not specified
amplitude: The scale of the kernel
normalize: If True, the kernel will be normalized such as values will sum to one
width (int, optional): The width of the kernel
height (int, optional): The height of the kernel
sigma_horz (float, optional): Horizontal standard deviation of the kernel
sigma_vert (float, optional): Vertical standard deviation of the kernel
mean_horz (float): Horizontal mean of the kernel
mean_vert (float): Vertical mean of the kernel
Returns:
np.array: The computed gaussian kernel
"""
# handle some defaults
if width is None:
width = size
if height is None:
height = size
if sigma_horz is None:
sigma_horz = sigma
if sigma_vert is None:
sigma_vert = sigma
center_x = mean_horz * width + 0.5
center_y = mean_vert * height + 0.5
gauss = np.empty((height, width), dtype=np.float32)
# generate kernel
for i in range(height):
for j in range(width):
gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
if normalize:
gauss = gauss / np.sum(gauss)
return gauss
# Adapted from: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
def draw_gaussian(image, point, sigma):
""" Draw gaussian circle at a point in an image.
Args:
image (np.array): An image of shape (H, W)
point (np.array): The center point of the guassian circle
sigma (float): Standard deviation of the gaussian kernel
Returns:
np.array: The image with the drawn gaussian.
"""
# Check if the gaussian is inside
point[0] = round(point[0], 2)
point[1] = round(point[1], 2)
ul = [math.floor(point[0] - 7.5 * sigma), math.floor(point[1] - 7.5 * sigma)]
br = [math.floor(point[0] + 7.5 * sigma), math.floor(point[1] + 7.5 * sigma)]
if (ul[0] > image.shape[1] or ul[1] >
image.shape[0] or br[0] < 1 or br[1] < 1):
return image
size = 15 * sigma + 1
g = _gaussian(size, sigma=0.1)
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) -
int(max(1, ul[0])) + int(max(1, -ul[0]))]
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) -
int(max(1, ul[1])) + int(max(1, -ul[1]))]
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
assert (g_x[0] > 0 and g_y[1] > 0)
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] = \
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
image[image > 1] = 1
return image
# Adapted from: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/api.py
def generate_heatmaps(height, width, points, sigma=None):
""" Generate heatmaps corresponding to a set of points.
Args:
height (int): Heatmap height
width (int): Heatmap width
points (np.array): An array of points of shape (N, 2)
sigma (float, optional): Standard deviation of the gaussian kernel. If not specified it will be determined
from the width of the heatmap
Returns:
np.array: The generated heatmaps.
"""
sigma = max(1, int(np.round(width / 128.))) if sigma is None else sigma
heatmaps = np.zeros((points.shape[0], height, width), dtype=np.float32)
for i in range(points.shape[0]):
if points[i, 0] > 0:
heatmaps[i] = draw_gaussian(
heatmaps[i], points[i], sigma)
return heatmaps
if __name__ == "__main__":
#you can use [X,2] matrix
points = np.array([
[(30.2946)+8, 51.6963],
[(65.5318)+8, 51.5014],
[(48.0252)+8, 71.7366],
[(33.5493)+8, 92.3655],
[(62.7299)+8, 92.2041]], dtype=np.float32)
heatmaps = generate_heatmaps(width = 112,
height = 112,
points = points,
sigma = 3)
print(heatmaps.shape)
final_heatmap = np.sum(heatmaps, axis=0)
cv2.imwrite("final_heatmap.png", final_heatmap*255)
print("end")