Files
insightface/parsing/dml_csr/utils/utils.py
QINGPING ZHENG ba7da5fa61 Create utils.py
2022-03-23 00:22:37 +08:00

107 lines
3.4 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Qingping Zheng
@Contact : qingpingzheng2014@gmail.com
@File : utils.py
@Time : 10/01/21 00:00 PM
@Desc :
@License : Licensed under the Apache License, Version 2.0 (the "License");
@Copyright : Copyright 2022 The Authors. All Rights Reserved.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torchvision
import torch
from PIL import Image
from torch import nn
# Colour map used by ``decode_parsing``: entry ``i`` is the (R, G, B) colour
# painted for label value ``i``. The class names are the ones given in the
# original comments (21 entries, index 0 = background).
COLORS = [
    (0, 0, 0),          # 0=background
    (128, 0, 0),        # 1=aeroplane
    (0, 128, 0),        # 2=bicycle
    (128, 128, 0),      # 3=bird
    (0, 0, 128),        # 4=boat
    (128, 0, 128),      # 5=bottle
    (0, 128, 128),      # 6=bus
    (128, 128, 128),    # 7=car
    (64, 0, 0),         # 8=cat
    (192, 0, 0),        # 9=chair
    (64, 128, 0),       # 10=cow
    (192, 128, 0),      # 11=diningtable
    (64, 0, 128),       # 12=dog
    (192, 0, 128),      # 13=horse
    (64, 128, 128),     # 14=motorbike
    (192, 128, 128),    # 15=person
    (0, 64, 0),         # 16=potted plant
    (128, 64, 0),       # 17=sheep
    (0, 192, 0),        # 18=sofa
    (128, 192, 0),      # 19=train
    (0, 64, 128),       # 20=tv/monitor
]


def decode_parsing(labels, num_images=1, num_classes=21, is_pred=False):
    """Decode a batch of segmentation label maps into RGB colour images.

    Args:
        labels: Tensor of label indices of shape (N, H, W) or, when
            ``is_pred`` is True, per-class scores of shape (N, C, H, W) that
            are reduced with ``argmax`` over dim 1.
        num_images: Number of images to decode from the front of the batch.
            Values larger than N simply decode the whole batch.
        num_classes: Number of classes (including background). Accepted for
            API compatibility; currently not used -- every entry of
            ``COLORS`` is painted.
        is_pred: Whether ``labels`` holds per-class scores (take argmax)
            rather than label indices.

    Returns:
        A uint8 tensor of shape (n, 3, H, W), n = min(num_images, N), where
        each pixel holds ``COLORS[label]``. Pixels whose label has no entry
        in ``COLORS`` (e.g. 255) stay black (0, 0, 0).
    """
    # Detached CPU copy so the caller's tensor is never modified.
    pred_labels = labels[:num_images].clone().cpu().data
    if is_pred:
        pred_labels = torch.argmax(pred_labels, dim=1)
    n, h, w = pred_labels.size()
    labels_color = torch.zeros([n, 3, h, w], dtype=torch.uint8)
    for class_idx, color in enumerate(COLORS):
        # Build the per-class mask once and reuse it for all three channels.
        mask = pred_labels == class_idx
        for channel, value in enumerate(color):
            # ``labels_color[:, channel]`` is a view, so the masked write
            # lands in ``labels_color`` itself.
            labels_color[:, channel][mask] = value
    return labels_color
def inv_preprocess(imgs, num_images, mean=(0.485, 0.456, 0.406),
                   std=(0.229, 0.224, 0.225)):
    """Undo mean/std normalization for the first images of a batch.

    Args:
        imgs: Batch of normalized images. Each ``imgs[i]`` is passed to a
            ``torchvision.transforms.Normalize``-style transform, so it is
            presumably a float tensor of shape (C, H, W) -- TODO confirm.
        num_images: Number of images, taken from the front of the batch, to
            un-normalize. Values larger than the batch size are clamped to it
            by the slice below.
        mean: Per-channel mean used by the forward normalization. Defaults to
            the values previously hard-coded here (the commonly used ImageNet
            statistics).
        std: Per-channel standard deviation used by the forward
            normalization. Defaults to the values previously hard-coded here.

    Returns:
        A detached CPU copy of ``imgs[:num_images]`` with the normalization
        reversed; spatial dimensions are unchanged. No channel reordering
        (e.g. BGR -> RGB) is performed.
    """
    rev_imgs = imgs[:num_images].clone().cpu().data
    rev_normalize = NormalizeInverse(mean=mean, std=std)
    # Iterate over what the slice actually holds: ``num_images`` may exceed
    # the batch size, and indexing past the end would raise IndexError.
    for i in range(rev_imgs.size(0)):
        rev_imgs[i] = rev_normalize(rev_imgs[i])
    return rev_imgs
class NormalizeInverse(torchvision.transforms.Normalize):
    """Transform that reverses a previous ``Normalize(mean, std)``.

    Forward normalization maps ``x`` to ``(x - mean) / std``. This subclass
    hands ``Normalize`` the parameters

        mean' = -mean / (std + 1e-7)
        std'  =  1 / (std + 1e-7)

    so applying it computes ``(x - mean') / std' = x * (std + 1e-7) + mean``,
    i.e. it maps normalized images back to the input domain. The 1e-7
    epsilon guards against a zero ``std``; the reconstruction is therefore
    approximate by that amount.
    """

    def __init__(self, mean, std):
        mean_t = torch.as_tensor(mean)
        std_t = torch.as_tensor(std)
        inv_std = 1 / (std_t + 1e-7)  # epsilon avoids division by zero
        super().__init__(mean=-mean_t * inv_std, std=inv_std)
class SingleGPU(nn.Module):
    """Wrap a module so its input is moved to CUDA before the forward pass.

    The wrapped module is stored as ``self.module``. This is presumably meant
    to mirror the attribute name used by ``nn.DataParallel``, so code that
    reaches through ``.module`` works either way; confirm with the callers.
    Only the input tensor is moved. The wrapped module's parameters are
    left wherever they already live.
    """

    def __init__(self, module):
        super().__init__()
        # When ``module`` is an nn.Module, this assignment registers it as
        # a child.
        self.module = module

    def forward(self, x):
        # ``x.cuda()`` targets the current CUDA device. With
        # ``non_blocking=True`` the host-to-device copy may run asynchronously
        # when ``x`` is in pinned memory.
        on_device = x.cuda(non_blocking=True)
        return self.module(on_device)