mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-16 05:27:56 +00:00
200 lines
9.8 KiB
Python
200 lines
9.8 KiB
Python
# Copyright (c) 2020, Baris Gecer. All rights reserved.
|
|
#
|
|
# This work is made available under the CC BY-NC-SA 4.0.
|
|
# To view a copy of this license, see LICENSE
|
|
|
|
import os
|
|
import argparse
|
|
import pickle
|
|
from tqdm.auto import tqdm
|
|
import PIL.Image
|
|
from PIL import ImageFilter
|
|
import numpy as np
|
|
import external.stylegan2.dnnlib.tflib as tflib
|
|
from external.stylegan2 import pretrained_networks
|
|
from core.generator_model import Generator
|
|
from core.perceptual_model import PerceptualModel, load_images
|
|
import external.stylegan2.dnnlib
|
|
from keras.models import load_model
|
|
from keras.applications.resnet50 import preprocess_input
|
|
|
|
|
|
def split_to_batches(l, n):
    """Yield consecutive slices of ``l`` that each hold at most ``n`` items.

    The final slice is shorter when ``len(l)`` is not a multiple of ``n``.
    """
    yield from (l[start:start + n] for start in range(0, len(l), n))
|
|
|
|
def str2bool(v):
    """Convert a command-line style token into a ``bool``.

    Booleans are passed through untouched. Strings are compared
    case-insensitively against the accepted spellings of true and false.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v

    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    token = v.lower()
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
|
|
|
|
|
|
class Projection_Handler():
    """Fit generator dlatents to batches of reference images.

    The handler owns a ``Generator`` built from networks loaded through
    ``external.stylegan2.pretrained_networks``, a ``PerceptualModel`` that
    drives the optimisation, and an optional feed-forward Keras model
    (``self.ff_model``) used to predict initial dlatents.
    """

    def __init__(self, args):
        """Load the networks and build the perceptual model.

        Args:
            args: option namespace. Attributes read here are ``model_url``,
                ``batch_size``, ``randomize_noise``, ``dlatent_avg``,
                ``use_lpips_loss``, ``vgg_url``, ``load_resnet``,
                ``load_effnet``, ``resnet_image_size`` and ``use_id_loss``.
                The whole namespace is also passed on to ``PerceptualModel``.
        """
        self.args = args

        # Initialize generator and perceptual model
        tflib.init_tf()
        _, discriminator_network, Gs_network = pretrained_networks.load_networks(args.model_url)
        self.generator = Generator(Gs_network, args.batch_size, randomize_noise=args.randomize_noise)

        if args.dlatent_avg != '':
            self.generator.set_dlatent_avg(np.load(args.dlatent_avg))

        perc_model = None
        if args.use_lpips_loss > 0.00000001:
            if external.stylegan2.dnnlib.util.is_url(args.vgg_url):
                stream = external.stylegan2.dnnlib.util.open_url(args.vgg_url, cache_dir='../.stylegan2-cache')
            else:
                stream = open(args.vgg_url, 'rb')
            with stream as f:
                # NOTE(security): unpickling executes arbitrary code; only
                # point ``vgg_url`` at a trusted file or URL.
                perc_model = pickle.load(f)

        self.perceptual_model = PerceptualModel(args, perc_model=perc_model, batch_size=args.batch_size)

        # Optional feed-forward encoder. The preprocessing function is stored
        # on the instance so that run_projection() applies the one matching
        # the model that was actually loaded. A function-local
        # ``import ... preprocess_input`` would be lost when __init__ returns,
        # leaving the module-level ResNet50 version in use for every model.
        self.ff_model = None
        self.preprocess_input = preprocess_input
        if os.path.exists(self.args.load_resnet):
            print("Loading ResNet Model:")
            self.ff_model = load_model(self.args.load_resnet)
            # Run one prediction on a dummy batch right after loading.
            dummy_im = np.zeros([args.batch_size, args.resnet_image_size, args.resnet_image_size, 3], np.uint8)
            self.ff_model.predict(self.preprocess_input(dummy_im))
        elif os.path.exists(self.args.load_effnet):
            from efficientnet import preprocess_input as effnet_preprocess_input
            print("Loading EfficientNet Model:")
            self.ff_model = load_model(self.args.load_effnet)
            self.preprocess_input = effnet_preprocess_input

        self.perceptual_model.build_perceptual_model(self.generator, discriminator_network)
        self.perceptual_model.assign_placeholder('id_loss', args.use_id_loss)

    def _initial_dlatents(self, names, images_batch):
        """Return starting dlatents for one batch, or ``None`` if unavailable.

        When ``args.load_last`` is set, ``<name>.npy`` is loaded from that
        directory for every name and the arrays are stacked along a new
        leading axis. Otherwise the feed-forward model, if one was loaded,
        predicts the dlatents from the images.
        """
        if self.args.load_last != '':  # load previous dlatents for initialization
            loaded = [
                np.expand_dims(np.load(os.path.join(self.args.load_last, f'{name}.npy')), axis=0)
                for name in names
            ]
            return np.vstack(loaded)

        if self.ff_model is None:
            return None

        # predict initial dlatents with the feed-forward model
        images = load_images(images_batch, image_size=self.args.resnet_image_size)
        if self.args.use_preprocess_input:
            images = self.preprocess_input(images)
        return self.ff_model.predict(images)

    def _composite_with_mask(self, img_array, img_path):
        """Blend a generated image over its original using a blurred mask.

        Returns ``img_array`` unchanged unless ``args.composite_mask`` is set,
        one of ``args.load_mask`` / ``args.face_mask`` is set, and a mask file
        named like ``img_path`` exists in ``args.mask_dir``.
        """
        if not (self.args.composite_mask and (self.args.load_mask or self.args.face_mask)):
            return img_array

        _, im_name = os.path.split(img_path)
        mask_path = os.path.join(self.args.mask_dir, f'{im_name}')
        if not os.path.isfile(mask_path):
            return img_array

        orig_img = PIL.Image.open(img_path).convert('RGB')
        width, height = orig_img.size
        imask = PIL.Image.open(mask_path).convert('L').resize((width, height))
        imask = imask.filter(ImageFilter.GaussianBlur(self.args.composite_blur))
        # Scale the 8-bit mask to [0, 1] and add a channel axis for broadcasting.
        mask = np.expand_dims(np.array(imask) / 255, axis=-1)
        blended = mask * np.array(img_array) + (1.0 - mask) * np.array(orig_img)
        return blended.astype(np.uint8)

    def run_projection(self, input_images, masks, heatmaps, id_features, iterations=None):
        """Optimise dlatents for every input image and return the results.

        Args:
            input_images: dict mapping a name to a reference image entry.
            masks: dict with the same keys, one mask entry per image.
            heatmaps: dict with the same keys, one heatmap entry per image.
            id_features: passed unchanged to
                ``PerceptualModel.set_reference_images`` for every batch.
            iterations: number of optimisation steps; defaults to
                ``args.iterations`` when ``None``.

        Returns:
            ``(return_imgs, return_dlatents)``: two dicts keyed by image name,
            holding the generated ``PIL.Image`` and its dlatent.
        """
        n_iteration = self.args.iterations if iterations is None else iterations
        return_imgs = {}
        return_dlatents = {}

        # Optimize (only) dlatents by minimizing perceptual loss between reference and generated images in feature space
        for names in split_to_batches(list(input_images), self.args.batch_size):
            images_batch = [input_images[x] for x in names]
            masks_batch = [masks[x] for x in names]
            heatmaps_batch = [heatmaps[x] for x in names]

            ## REGRESSION
            dlatents = self._initial_dlatents(names, images_batch)
            if dlatents is not None:
                self.generator.set_dlatents(dlatents)

            ## OPTIMIZATION
            self.perceptual_model.set_reference_images(images_batch, masks_batch, heatmaps_batch, id_features)

            op = self.perceptual_model.optimize(self.generator.dlatent_variable, iterations=n_iteration)
            pbar = tqdm(op, leave=False, total=n_iteration)
            best_loss = None
            best_dlatent = None
            avg_loss_count = 0
            # Always reset per batch so a run with zero optimizer steps does
            # not read an unbound (or stale, previous-batch) value below.
            avg_loss = prev_loss = None

            for loss_dict in pbar:
                if self.args.early_stopping and prev_loss is not None:  # early stopping feature
                    if avg_loss is None:
                        avg_loss = prev_loss - loss_dict["loss"]
                    else:
                        # Running average of the per-step loss improvement.
                        avg_loss = 0.5 * avg_loss + (prev_loss - loss_dict["loss"])
                        if avg_loss < self.args.early_stopping_threshold:  # count while under threshold; else reset
                            avg_loss_count += 1
                        else:
                            avg_loss_count = 0
                        if avg_loss_count > self.args.early_stopping_patience:  # stop once threshold is reached
                            print("")
                            break

                pbar.set_description(
                    " ".join(names) + ": " + "; ".join(["{} {:.4f}".format(k, v) for k, v in loss_dict.items()]))

                if best_loss is None or loss_dict["loss"] < best_loss:
                    if best_dlatent is None or self.args.average_best_loss <= 0.00000001:
                        best_dlatent = self.generator.get_dlatents()
                    else:
                        best_dlatent = 0.25 * best_dlatent + 0.75 * self.generator.get_dlatents()
                    if self.args.use_best_loss:
                        self.generator.set_dlatents(best_dlatent)
                    best_loss = loss_dict["loss"]

                self.generator.stochastic_clip_dlatents()
                prev_loss = loss_dict["loss"]

            if not self.args.use_best_loss:
                best_loss = prev_loss
            # best_loss stays None when the optimizer yielded no step at all.
            loss_text = "n/a" if best_loss is None else "{:.4f}".format(best_loss)
            print(" ".join(names), " Loss " + loss_text)

            # Generate images from found dlatents and collect them
            if self.args.use_best_loss and best_dlatent is not None:
                self.generator.set_dlatents(best_dlatent)

            generated_images = self.generator.generate_images()
            generated_dlatents = self.generator.get_dlatents()
            for img_array, dlatent, img_path, img_name in zip(generated_images, generated_dlatents, images_batch,
                                                              names):
                img_array = self._composite_with_mask(img_array, img_path)
                return_imgs[img_name] = PIL.Image.fromarray(img_array, 'RGB')
                return_dlatents[img_name] = dlatent

            self.generator.reset_dlatents()

        return return_imgs, return_dlatents
|