import timeit

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms

from FaceHairMask import graph
from FaceHairMask import graphonomy_process as tr

# CIHP palette: every class is rendered black except hair (class 2, red)
# and face (class 13, blue).
label_colours = [(0, 0, 0) for i in range(20)]
label_colours[2] = (255, 0, 0)
label_colours[13] = (0, 0, 255)

def custom_decode_labels(mask, num_images=1, num_classes=20):
    """Extract binary hair and face masks from a batch of segmentation masks.

    Args:
        mask: result of inference after taking argmax, shape [n, h, w].
        num_images: number of images to decode from the batch.
        num_classes: number of classes to predict (including background).

    Returns:
        A (hair_mask, face_mask) pair of binary tensors of the same shape as
        the input, marking CIHP class 2 (hair) and class 13 (face).
    """
    n, h, w = mask.shape
    assert (
        n >= num_images
    ), "Batch size %d should be greater than or equal to the number of images to save %d." % (
        n,
        num_images,
    )

    hair_mask = torch.where(mask == 2, torch.ones_like(mask), torch.zeros_like(mask))
    face_mask = torch.where(mask == 13, torch.ones_like(mask), torch.zeros_like(mask))

    return hair_mask, face_mask

def overlay(frame, mask):
    """Alpha-blend an RGB colour mask on top of a frame."""
    mask = np.array(mask)
    frame = np.array(frame)

    # build an alpha channel that is opaque wherever the mask is non-black
    tmp = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
    _, alpha = cv2.threshold(tmp, 0, 255, cv2.THRESH_BINARY)
    r, g, b = cv2.split(mask)  # the mask is stored in RGB channel order
    rgba = [r, g, b, alpha]
    dst = cv2.merge(rgba)

    # overlay the mask on the frame; addWeighted needs matching channel
    # counts, so give the frame an alpha channel as well
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2RGBA)
    overlaid_image = cv2.addWeighted(frame, 0.4, dst, 0.1, 0)
    return overlaid_image
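
# Hedged usage sketch (not in the original file): `overlay` is typically fed a
# video frame plus the RGB visualisation produced by `decode_labels` below;
# the variable names are illustrative only.
#
#   results = predictions.cpu().numpy()   # argmax output, shape [n, h, w]
#   vis_res = decode_labels(results)      # RGB masks, shape [n, h, w, 3]
#   blended = overlay(frame, vis_res[0])  # RGBA blend of frame and mask
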
def flip(x, dim):
    """Reverse tensor `x` along dimension `dim` (equivalent to torch.flip)."""
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(
        x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device
    )
    return x[tuple(indices)]

def flip_cihp(tail_list):
    """Swap the left/right CIHP channels of a prediction for a flipped image.

    :param tail_list: tensor of size n_class x h x w
    :return: tensor of the same size with left/right class channels swapped
    """
    # CIHP left/right pairs: 14/15 (arms), 16/17 (legs), 18/19 (shoes);
    # classes 0-13 have no left/right counterpart and are copied unchanged.
    tail_list_rev = [None] * 20
    for xx in range(14):
        tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
    tail_list_rev[14] = tail_list[15].unsqueeze(0)
    tail_list_rev[15] = tail_list[14].unsqueeze(0)
    tail_list_rev[16] = tail_list[17].unsqueeze(0)
    tail_list_rev[17] = tail_list[16].unsqueeze(0)
    tail_list_rev[18] = tail_list[19].unsqueeze(0)
    tail_list_rev[19] = tail_list[18].unsqueeze(0)
    return torch.cat(tail_list_rev, dim=0)
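
# Sketch of how `flip` and `flip_cihp` combine for flip test-time augmentation
# (this mirrors the call inside `inference` below): a prediction made on a
# horizontally flipped image is mapped back by swapping the left/right class
# channels and then reversing the spatial width axis.
#
#   pred_f = ...                                 # [20, h, w] prediction for
#                                                # the flipped input
#   restored = flip(flip_cihp(pred_f), dim=-1)   # aligned with the original
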
def decode_labels(mask, num_images=1, num_classes=20):
    """Decode a batch of segmentation masks into RGB visualisations.

    Args:
        mask: result of inference after taking argmax.
        num_images: number of images to decode from the batch.
        num_classes: number of classes to predict (including background).

    Returns:
        A batch with num_images RGB images of the same size as the input.
    """
    n, h, w = mask.shape
    assert (
        n >= num_images
    ), "Batch size %d should be greater than or equal to the number of images to save %d." % (
        n,
        num_images,
    )
    outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
    for i in range(num_images):
        img = Image.new("RGB", (len(mask[i, 0]), len(mask[i])))
        pixels = img.load()
        for j_, j in enumerate(mask[i, :, :]):
            for k_, k in enumerate(j):
                if k < num_classes:
                    pixels[k_, j_] = label_colours[k]
        outputs[i] = np.array(img)
    return outputs
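
# A vectorised alternative sketch (an assumption, not upstream code): when all
# class ids in `mask` are below num_classes, the per-pixel loop above reduces
# to a palette lookup:
#
#   palette = np.array(label_colours, dtype=np.uint8)   # [num_classes, 3]
#   outputs = palette[mask[:num_images]]                # [num_images, h, w, 3]
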
def read_img(img_path):
    _img = Image.open(img_path).convert("RGB")  # returned image is RGB
    return _img

def img_transform(img, transform=None):
    # the label entry is a dummy value required by the transform interface
    sample = {"image": img, "label": 0}

    sample = transform(sample)
    return sample

def inference(net, img=None, device=None):
    """Run multi-scale parsing inference and return hair/face masks.

    :param net: Graphonomy-style parsing network.
    :param img: input tensor of shape [1, 3, h, w] with values in [-1, 1].
    :param device: torch device to run inference on.
    :return: (outputs_final, hair_mask, face_mask)
    """
    # adjacency matrices for the graph-reasoning branches (the adj naming is
    # kept as in the original code: adj1_ feeds adj3_test, adj3_ feeds
    # adj1_test)
    adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
    adj2_test = (
        adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).to(device).transpose(2, 3)
    )

    adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
    adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).to(device)

    cihp_adj = graph.preprocess_adj(graph.cihp_graph)
    adj3_ = Variable(torch.from_numpy(cihp_adj).float())
    adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).to(device)

    # multi-scale
    scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75]
    # scale_list = [1, 0.5, 0.75, 1.25]

    # NOTE: the code below assumes img is a PIL image in RGB color space, but
    # we are given a torch tensor in range [-1, 1], so bring it to [0, 255].
    img = torch.clamp(img, -1, 1)
    img = (img + 1.0) / 2.0
    img *= 255

    testloader_list = []
    testloader_flip_list = []
    for pv in scale_list:
        composed_transforms_ts = transforms.Compose(
            [
                tr.Scale_only_img(pv),
                tr.Normalize_xception_tf_only_img(),
                tr.ToTensor_only_img(),
            ]
        )

        composed_transforms_ts_flip = transforms.Compose(
            [
                tr.Scale_only_img(pv),
                tr.HorizontalFlip_only_img(),
                tr.Normalize_xception_tf_only_img(),
                tr.ToTensor_only_img(),
            ]
        )
        # NOTE: img [1, 3, 256, 256], (min, max) = (0, 255)

        # print("original:", img.shape, img.min(), img.max())
        testloader_list.append(img_transform(img, composed_transforms_ts))
        # print(img_transform(img, composed_transforms_ts))
        testloader_flip_list.append(img_transform(img, composed_transforms_ts_flip))
    # print(testloader_list)
    start_time = timeit.default_timer()
    # One testing epoch
    # net.eval()
    # scales 1, 0.5, 0.75, 1.25, 1.5, 1.75, each paired with a flipped copy

    # NOTE: testloader_list[0]['image'].shape = 3, 420, 620

    for iii, sample_batched in enumerate(zip(testloader_list, testloader_flip_list)):
        inputs, labels = sample_batched[0]["image"], sample_batched[0]["label"]
        inputs_f, _ = sample_batched[1]["image"], sample_batched[1]["label"]
        inputs = inputs.unsqueeze(0)
        inputs_f = inputs_f.unsqueeze(0)
        # batch the original and the flipped image together
        inputs = torch.cat((inputs, inputs_f), dim=0)
        if iii == 0:
            _, _, h, w = inputs.size()
        # assert inputs.size() == inputs_f.size()

        # Forward pass of the mini-batch

        # TODO: check requires grad functionality
        # inputs = Variable(inputs, requires_grad=False)

        with torch.no_grad():
            if device is not None:
                inputs = inputs.to(device)
            # outputs = net.forward(inputs)
            outputs = net.forward(
                inputs, adj1_test.to(device), adj3_test.to(device), adj2_test.to(device)
            )
            # average the prediction for the original image with the
            # left/right-restored prediction for the flipped image
            outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2
            outputs = outputs.unsqueeze(0)
            if iii > 0:
                # F.upsample is deprecated; F.interpolate is the equivalent call
                outputs = F.interpolate(
                    outputs, size=(h, w), mode="bilinear", align_corners=True
                )
                outputs_final = outputs_final + outputs
            else:
                outputs_final = outputs.clone()

    ################ plot pic
    predictions = torch.max(outputs_final, 1)[1]
    # results = predictions.cpu().numpy()
    # vis_res = decode_labels(results)
    # parsing_im = Image.fromarray(vis_res[0])
    # return parsing_im

    hair_mask, face_mask = custom_decode_labels(predictions)

    return outputs_final, hair_mask, face_mask

    # parsing_im.save(output_path+'/{}.png'.format(output_name))
    # cv2.imwrite(output_path+'/{}_gray.png'.format(output_name), results[0, :, :])

    # end_time = timeit.default_timer()
    # print('time used for the multi-scale image inference' + ' is :' + str(end_time - start_time))
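
if __name__ == "__main__":
    # Minimal usage sketch under stated assumptions: `load_parsing_net` is a
    # hypothetical helper (not part of this file); the real caller must supply
    # a Graphonomy-style network and a [-1, 1]-normalized RGB tensor.
    from FaceHairMask.model import load_parsing_net  # hypothetical import

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = load_parsing_net().to(device).eval()

    dummy = torch.rand(1, 3, 256, 256) * 2 - 1  # dummy RGB input in [-1, 1]
    outputs_final, hair_mask, face_mask = inference(net, img=dummy, device=device)
    print(outputs_final.shape, hair_mask.shape, face_mask.shape)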