# insightface/reconstruction/ostec/core/projection_handler.py

# Copyright (c) 2020, Baris Gecer. All rights reserved.
#
# This work is made available under the CC BY-NC-SA 4.0.
# To view a copy of this license, see LICENSE
import os
import argparse
import pickle
from tqdm.auto import tqdm
import PIL.Image
from PIL import ImageFilter
import numpy as np
import external.stylegan2.dnnlib.tflib as tflib
from external.stylegan2 import pretrained_networks
from core.generator_model import Generator
from core.perceptual_model import PerceptualModel, load_images
import external.stylegan2.dnnlib
from keras.models import load_model
from keras.applications.resnet50 import preprocess_input


def split_to_batches(l, n):
    for i in range(0, len(l), n):
        yield l[i:i + n]
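
# Example: list(split_to_batches(['a', 'b', 'c'], 2)) -> [['a', 'b'], ['c']]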


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
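
# Usage sketch (hypothetical wiring; OSTeC's real argument setup lives
# elsewhere, so the flag name below is an assumption):
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--early_stopping', type=str2bool, default=True)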


class Projection_Handler():
    def __init__(self, args):
        self.args = args

        # Initialize generator and perceptual model
        tflib.init_tf()
        generator_network, discriminator_network, Gs_network = pretrained_networks.load_networks(args.model_url)

        self.generator = Generator(Gs_network, args.batch_size, randomize_noise=args.randomize_noise)
        if args.dlatent_avg != '':
            self.generator.set_dlatent_avg(np.load(args.dlatent_avg))

        perc_model = None
        if args.use_lpips_loss > 0.00000001:
            if external.stylegan2.dnnlib.util.is_url(args.vgg_url):
                stream = external.stylegan2.dnnlib.util.open_url(args.vgg_url, cache_dir='../.stylegan2-cache')
            else:
                stream = open(args.vgg_url, 'rb')
            with stream as f:
                perc_model = pickle.load(f)

        self.perceptual_model = PerceptualModel(args, perc_model=perc_model, batch_size=args.batch_size)
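
        # perc_model (the pickled VGG network, when loaded above) drives the
        # LPIPS perceptual term inside PerceptualModel; it stays None when
        # use_lpips_loss is effectively zero.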

        self.ff_model = None
        if self.ff_model is None:
            if os.path.exists(self.args.load_resnet):
                from keras.applications.resnet50 import preprocess_input
                print("Loading ResNet Model:")
                self.ff_model = load_model(self.args.load_resnet)
                # self.ff_model._make_predict_function()
                # Push one dummy batch through the network so Keras builds its
                # predict function now instead of on the first real call.
                dummy_im = np.zeros([args.batch_size, args.resnet_image_size, args.resnet_image_size, 3], np.uint8)
                self.ff_model.predict(preprocess_input(dummy_im))
        if self.ff_model is None:
            if os.path.exists(self.args.load_effnet):
                from efficientnet import preprocess_input
                print("Loading EfficientNet Model:")
                self.ff_model = load_model(self.args.load_effnet)

        self.perceptual_model.build_perceptual_model(self.generator, discriminator_network)
        self.perceptual_model.assign_placeholder('id_loss', args.use_id_loss)

    def run_projection(self, input_images, masks, heatmaps, id_features, iterations=None):
        n_iteration = self.args.iterations
        if iterations is not None:
            n_iteration = iterations

        return_imgs = {}
        return_dlatents = {}

        # Optimize (only) dlatents by minimizing perceptual loss between reference and generated images in feature space
        for names in split_to_batches(list(input_images.keys()), self.args.batch_size):
            # tqdm(split_to_batches(list(input_images.keys()), self.args.batch_size),
            #      total=len(input_images) // self.args.batch_size):
            # tqdm._instances.clear()
            images_batch = [input_images[x] for x in names]
            masks_batch = [masks[x] for x in names]
            heatmaps_batch = [heatmaps[x] for x in names]

            # if args.output_video:
            #     video_out = {}
            #     for name in names:
            #         video_out[name] = cv2.VideoWriter(os.path.join(args.video_dir, f'{name}.avi'),
            #                                           cv2.VideoWriter_fourcc(*args.video_codec), args.video_frame_rate,
            #                                           (args.video_size, args.video_size))

            ## REGRESSION
            dlatents = None
            if self.args.load_last != '':  # load previous dlatents for initialization
                for name in names:
                    dl = np.expand_dims(np.load(os.path.join(self.args.load_last, f'{name}.npy')), axis=0)
                    if dlatents is None:
                        dlatents = dl
                    else:
                        dlatents = np.vstack((dlatents, dl))
            else:
                if self.ff_model is not None:  # predict initial dlatents with ResNet model
                    if self.args.use_preprocess_input:
                        dlatents = self.ff_model.predict(
                            preprocess_input(load_images(images_batch, image_size=self.args.resnet_image_size)))
                    else:
                        dlatents = self.ff_model.predict(load_images(images_batch, image_size=self.args.resnet_image_size))

            if dlatents is not None:
                self.generator.set_dlatents(dlatents)
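
            # At this point dlatents holds one latent per image, either loaded
            # from disk or regressed by the feed-forward model; for the standard
            # 1024x1024 StyleGAN2 generator it would typically have shape
            # [batch_size, 18, 512] (an assumption about the checkpoint in use).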

            ## OPTIMIZATION
            self.perceptual_model.set_reference_images(images_batch, masks_batch, heatmaps_batch, id_features)
            op = self.perceptual_model.optimize(self.generator.dlatent_variable, iterations=n_iteration)
            pbar = tqdm(op, leave=False, total=n_iteration)
            vid_count = 0
            best_loss = None
            best_dlatent = None
            avg_loss_count = 0
            if self.args.early_stopping:
                avg_loss = prev_loss = None
            for loss_dict in pbar:
                if self.args.early_stopping:  # early stopping feature
                    if prev_loss is not None:
                        if avg_loss is not None:
                            # Decay the running improvement estimate and add the
                            # newest per-step loss improvement.
                            avg_loss = 0.5 * avg_loss + (prev_loss - loss_dict["loss"])
                            if avg_loss < self.args.early_stopping_threshold:  # count while under threshold; else reset
                                avg_loss_count += 1
                            else:
                                avg_loss_count = 0
                            if avg_loss_count > self.args.early_stopping_patience:  # stop once patience is exhausted
                                print("")
                                break
                        else:
                            avg_loss = prev_loss - loss_dict["loss"]
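
                # Worked example with made-up numbers: with a threshold of 0.03
                # and a steady per-step improvement of 0.01, avg_loss settles
                # around 0.02 (< 0.03), so avg_loss_count grows every step and
                # the loop stops once it exceeds early_stopping_patience.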
                pbar.set_description(
                    " ".join(names) + ": " + "; ".join(["{} {:.4f}".format(k, v) for k, v in loss_dict.items()]))
                if best_loss is None or loss_dict["loss"] < best_loss:
                    if best_dlatent is None or self.args.average_best_loss <= 0.00000001:
                        best_dlatent = self.generator.get_dlatents()
                    else:
                        best_dlatent = 0.25 * best_dlatent + 0.75 * self.generator.get_dlatents()
                    if self.args.use_best_loss:
                        self.generator.set_dlatents(best_dlatent)
                    best_loss = loss_dict["loss"]
                # if self.args.output_video and (vid_count % self.args.video_skip == 0):
                #     batch_frames = self.generator.generate_images()
                #     for i, name in enumerate(names):
                #         video_frame = PIL.Image.fromarray(batch_frames[i], 'RGB').resize(
                #             (self.args.video_size, self.args.video_size), PIL.Image.LANCZOS)
                #         video_out[name].write(cv2.cvtColor(np.array(video_frame).astype('uint8'), cv2.COLOR_RGB2BGR))
                self.generator.stochastic_clip_dlatents()
                prev_loss = loss_dict["loss"]

            if not self.args.use_best_loss:
                best_loss = prev_loss
            # pbar.set_postfix(loss="{:.4f}".format(best_loss))
            print(" ".join(names), " Loss {:.4f}".format(best_loss))

            # if self.args.output_video:
            #     for name in names:
            #         video_out[name].release()
            # Generate images from found dlatents and save them
            if self.args.use_best_loss:
                self.generator.set_dlatents(best_dlatent)
            generated_images = self.generator.generate_images()
            generated_dlatents = self.generator.get_dlatents()
            for img_array, dlatent, img_path, img_name in zip(generated_images, generated_dlatents,
                                                              images_batch, names):
                mask_img = None
                if self.args.composite_mask and (self.args.load_mask or self.args.face_mask):
                    _, im_name = os.path.split(img_path)
                    mask_img = os.path.join(self.args.mask_dir, f'{im_name}')
                if self.args.composite_mask and mask_img is not None and os.path.isfile(mask_img):
                    # Alpha-blend the generated image back into the original
                    # frame using a Gaussian-blurred composite mask.
                    orig_img = PIL.Image.open(img_path).convert('RGB')
                    width, height = orig_img.size
                    imask = PIL.Image.open(mask_img).convert('L').resize((width, height))
                    imask = imask.filter(ImageFilter.GaussianBlur(self.args.composite_blur))
                    mask = np.array(imask) / 255
                    mask = np.expand_dims(mask, axis=-1)
                    img_array = mask * np.array(img_array) + (1.0 - mask) * np.array(orig_img)
                    img_array = img_array.astype(np.uint8)
                img = PIL.Image.fromarray(img_array, 'RGB')
                return_imgs[img_name] = img
                return_dlatents[img_name] = dlatent

            self.generator.reset_dlatents()

        return return_imgs, return_dlatents
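

# Minimal driver sketch (hypothetical; OSTeC builds `args` from its own config,
# so the field values below are assumptions based only on the attributes this
# file reads):
#
#   import types
#   args = types.SimpleNamespace(
#       model_url='gdrive:networks/stylegan2-ffhq-config-f.pkl',  # assumed checkpoint
#       batch_size=1, randomize_noise=False, dlatent_avg='',
#       use_lpips_loss=100, vgg_url='...', load_resnet='...', load_effnet='',
#       resnet_image_size=256, use_id_loss=1.0, iterations=400,
#       early_stopping=True, early_stopping_threshold=0.03,
#       early_stopping_patience=10, use_best_loss=True, average_best_loss=0.25,
#       load_last='', use_preprocess_input=True, composite_mask=False)
#   handler = Projection_Handler(args)
#   imgs, dlatents = handler.run_projection(input_images, masks, heatmaps, id_features)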