# Source: insightface/reconstruction/ostec/external/graphonomy/FaceHairMask/graphonomy_inference.py
# (page-scrape metadata preserved as comment: 2022-05-29 14:26:46 +01:00, 242 lines, 7.5 KiB, Python)
import cv2
import torch
import timeit
import numpy as np
from PIL import Image
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms
from FaceHairMask import graph
from FaceHairMask import graphonomy_process as tr
# RGB palette for the 20 CIHP classes: everything black except the two
# classes this pipeline cares about.
label_colours = [(0, 0, 0)] * 20
label_colours[2] = (255, 0, 0)   # class 2: hair
label_colours[13] = (0, 0, 255)  # class 13: face
def custom_decode_labels(mask, num_images=1, num_classes=20):
    """Extract binary hair and face masks from a batch of argmax label maps.

    Args:
        mask: result of inference after taking argmax, shape (n, h, w).
        num_images: number of images expected in the batch (only validated
            against the batch size; all images are masked regardless).
        num_classes: unused; kept for signature compatibility with
            ``decode_labels``.
    Returns:
        Tuple ``(hair_mask, face_mask)`` of tensors shaped and typed like
        ``mask``, holding 1 where the label equals 2 (hair) or 13 (face)
        respectively, and 0 elsewhere.
    """
    n, h, w = mask.shape
    assert (
        n >= num_images
    ), "Batch size %d should be greater or equal than number of images to save %d." % (
        n,
        num_images,
    )
    # CIHP labels: 2 = hair, 13 = face (matches label_colours above).
    hair_mask = torch.where(mask == 2, torch.ones_like(mask), torch.zeros_like(mask))
    face_mask = torch.where(mask == 13, torch.ones_like(mask), torch.zeros_like(mask))
    return hair_mask, face_mask
def overlay(frame, mask):
    """Blend a colour segmentation mask on top of a frame.

    :param frame: background image (anything np.array can convert).
    :param mask: RGB mask image; black pixels are treated as transparent.
    :return: weighted blend of ``frame`` (weight 0.4) and the
        alpha-extended mask (weight 0.1).
    """
    mask = np.array(mask)
    frame = np.array(frame)
    # Any non-black mask pixel gets full alpha (255); pure black stays 0.
    tmp = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
    _, alpha = cv2.threshold(tmp, 0, 255, cv2.THRESH_BINARY)
    # Variable names say b, g, r but the mask is RGB per the cvtColor flag
    # above; order is preserved by the merge, so the blend is unaffected.
    b, g, r = cv2.split(mask)
    rgba = [b, g, r, alpha]
    dst = cv2.merge(rgba, 4)
    # overlay mask on frame
    # NOTE(review): cv2.addWeighted requires both inputs to have the same
    # channel count; dst is 4-channel here, so frame is presumably 4-channel
    # too -- confirm against callers, a 3-channel frame would raise.
    overlaid_image = cv2.addWeighted(frame, 0.4, dst, 0.1, 0)
    return overlaid_image
def flip(x, dim):
    """Reverse tensor ``x`` along dimension ``dim`` (negative dims allowed).

    Replaces the original hand-rolled reversed-``arange`` indexing with the
    equivalent built-in ``torch.flip``.
    """
    return torch.flip(x, dims=[dim])
def flip_cihp(tail_list):
    """Swap the left/right CIHP class channels of a flipped prediction.

    :param tail_list: tensor whose FIRST dimension is the 20 CIHP class
        channels, i.e. shape (20, h, w). (The old docstring claimed
        1 x n_class x h x w, which contradicts the indexing.)
    :return: new tensor of the same shape with channel pairs 14<->15,
        16<->17 and 18<->19 swapped; channels 0-13 are side-agnostic and
        kept in place.
    """
    # Identity on the first 14 channels, then swap each left/right pair.
    perm = list(range(14)) + [15, 14, 17, 16, 19, 18]
    return tail_list[perm]
def decode_labels(mask, num_images=1, num_classes=20):
    """Decode a batch of segmentation masks into RGB colour images.

    Args:
        mask: result of inference after taking argmax, shape (n, h, w),
            integer labels.
        num_images: number of images to decode from the batch.
        num_classes: number of classes to predict (including background);
            pixels whose label is >= num_classes are left black.
    Returns:
        A uint8 array of shape (num_images, h, w, 3) coloured via
        ``label_colours``.
    """
    n, h, w = mask.shape
    assert (
        n >= num_images
    ), "Batch size %d should be greater or equal than number of images to save %d." % (
        n,
        num_images,
    )
    # Vectorized palette lookup replaces the original per-pixel PIL loop
    # (same output, O(h*w) Python iterations -> one numpy gather).
    palette = np.asarray(label_colours, dtype=np.uint8)
    outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
    for i in range(num_images):
        labels = np.asarray(mask[i])
        valid = labels < num_classes
        outputs[i][valid] = palette[labels[valid]]
    return outputs
def read_img(img_path):
    """Load the image at ``img_path`` and return it as an RGB PIL image."""
    return Image.open(img_path).convert("RGB")
def img_transform(img, transform=None):
    """Wrap ``img`` in a Graphonomy sample dict and optionally transform it.

    :param img: image to wrap under the "image" key (label is a dummy 0).
    :param transform: callable taking and returning a sample dict; if None
        the untransformed sample is returned (the old code crashed with
        ``TypeError: 'NoneType' object is not callable`` despite advertising
        the None default).
    :return: the (possibly transformed) sample dict.
    """
    sample = {"image": img, "label": 0}
    if transform is not None:
        sample = transform(sample)
    return sample
def inference(net, img=None, device=None):
    """Run multi-scale Graphonomy inference on a single image tensor.

    :param net: Graphonomy network exposing
        ``forward(inputs, adj1, adj3, adj2)``.
    :param img: input tensor in range [-1, 1] (per the original in-code
        notes, assumed shape [1, 3, H, W] -- TODO confirm against callers).
    :param device: torch device to move the inputs and adjacency tensors to;
        may be None.
    :return: tuple ``(outputs_final, hair_mask, face_mask)`` where
        ``outputs_final`` is the per-class score map summed over all scales
        at the scale-1 resolution, and the masks are the binary hair/face
        maps from ``custom_decode_labels``.
    """
    # Adjacency matrices consumed by the graph-reasoning layers.
    # (Deprecated torch.autograd.Variable wrappers removed -- they have been
    # no-ops since PyTorch 0.4.)
    adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
    adj2_test = (
        adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).to(device).transpose(2, 3)
    )
    adj1_ = torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float()
    adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).to(device)
    cihp_adj = graph.preprocess_adj(graph.cihp_graph)
    adj3_ = torch.from_numpy(cihp_adj).float()
    adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).to(device)

    # Multi-scale evaluation; scale 1 must come FIRST because the loop below
    # takes the reference (h, w) from the first iteration.
    scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75]

    # NOTE: the transforms assume a PIL-style [0, 255] range, while we are
    # handed a tensor in [-1, 1] -- rescale first.
    img = torch.clamp(img, -1, 1)
    img = (img + 1.0) / 2.0
    img *= 255

    # Build one plain and one horizontally-flipped sample per scale.
    testloader_list = []
    testloader_flip_list = []
    for pv in scale_list:
        composed_transforms_ts = transforms.Compose(
            [
                tr.Scale_only_img(pv),
                tr.Normalize_xception_tf_only_img(),
                tr.ToTensor_only_img(),
            ]
        )
        composed_transforms_ts_flip = transforms.Compose(
            [
                tr.Scale_only_img(pv),
                tr.HorizontalFlip_only_img(),
                tr.Normalize_xception_tf_only_img(),
                tr.ToTensor_only_img(),
            ]
        )
        testloader_list.append(img_transform(img, composed_transforms_ts))
        testloader_flip_list.append(img_transform(img, composed_transforms_ts_flip))

    for iii, sample_batched in enumerate(zip(testloader_list, testloader_flip_list)):
        # Batch the original and the flipped image together: row 0 plain,
        # row 1 flipped.
        inputs = sample_batched[0]["image"].unsqueeze(0)
        inputs_f = sample_batched[1]["image"].unsqueeze(0)
        inputs = torch.cat((inputs, inputs_f), dim=0)
        if iii == 0:
            # Reference resolution (scale 1) that every other scale's output
            # is resized back to before accumulation.
            _, _, h, w = inputs.size()
        with torch.no_grad():
            if device is not None:
                inputs = inputs.to(device)
            outputs = net.forward(
                inputs, adj1_test.to(device), adj3_test.to(device), adj2_test.to(device)
            )
            # Average the plain prediction with the flipped prediction after
            # undoing the flip both spatially and on the left/right class
            # channels.
            outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2
            outputs = outputs.unsqueeze(0)
            if iii > 0:
                # F.interpolate replaces the deprecated F.upsample alias;
                # identical behavior with these arguments.
                outputs = F.interpolate(
                    outputs, size=(h, w), mode="bilinear", align_corners=True
                )
                outputs_final = outputs_final + outputs
            else:
                outputs_final = outputs.clone()

    # Argmax over the class dimension, then extract the hair/face masks.
    predictions = torch.max(outputs_final, 1)[1]
    hair_mask, face_mask = custom_decode_labels(predictions)
    return outputs_final, hair_mask, face_mask