mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
181 lines
5.7 KiB
Python
181 lines
5.7 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
"""
|
|
@Author : Qingping Zheng
|
|
@Contact : qingpingzheng2014@gmail.com
|
|
@File : datasets.py
|
|
@Time : 10/01/21 00:00 PM
|
|
@Desc :
|
|
@License : Licensed under the Apache License, Version 2.0 (the "License");
|
|
@Copyright : Copyright 2015 The Authors. All Rights Reserved.
|
|
"""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import os.path as osp
|
|
import pickle
|
|
import random
|
|
|
|
import torch
|
|
import torchvision.transforms.functional as TF
|
|
|
|
from glob import glob
|
|
from typing import Tuple
|
|
from utils import transforms
|
|
|
|
|
|
class FaceDataSet(torch.utils.data.Dataset):
|
|
"""Face data set for model training and validating
|
|
|
|
Examples:
|
|
|
|
./CelebAMask
|
|
|---test
|
|
|---train
|
|
|---images
|
|
|---0.jpg
|
|
|---1.jpg
|
|
|---labels
|
|
|---0.png
|
|
|---1.png
|
|
|---edges
|
|
|---0.png
|
|
|---1.png
|
|
|---valid
|
|
|---label_names.txt
|
|
|---test_list.txt
|
|
|---train_list.txt
|
|
|---images/0.jpg labels/0.png
|
|
|---images/1.jpg labels/1.png
|
|
|---valid_list.txt
|
|
|
|
Args:
|
|
root: A string, training/validating dataset path, e.g. "./CelebAMask"
|
|
dataset: A string, one of `"train"`, `"test"`, `"valid"`.
|
|
crop_size: A list of two intergers.
|
|
scale_factor: A float number.
|
|
rotation_factor: An integer number.
|
|
ignore_label: An integer number, default is 255.
|
|
transformer: A function of torchvision.transforms.Compose([])
|
|
"""
|
|
def __init__(self,
|
|
root: str,
|
|
dataset: str,
|
|
crop_size: list=[473, 473],
|
|
scale_factor: float=0.25,
|
|
rotation_factor: int=30,
|
|
ignore_label: int =255,
|
|
transform=None) -> None:
|
|
|
|
self.root = root
|
|
self.dataset = dataset
|
|
self.crop_size = np.asarray(crop_size)
|
|
self.scale_factor = scale_factor
|
|
self.rotation_factor = rotation_factor
|
|
self.ignore_label = ignore_label
|
|
self.transform = transform
|
|
|
|
self.flip_prob = 0.5
|
|
self.flip_pairs = [[4, 5], [6, 7]]
|
|
self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
|
|
|
|
self.file_list_name = osp.join(root, dataset + '_list.txt')
|
|
self.im_list = [line.split()[0][7:-4] for line in open(self.file_list_name).readlines()]
|
|
self.number_samples = len(self.im_list)
|
|
|
|
|
|
def __len__(self) -> int:
|
|
return self.number_samples
|
|
|
|
def _box2cs(self, box: list) -> tuple:
|
|
x, y, w, h = box[:4]
|
|
return self._xywh2cs(x, y, w, h)
|
|
|
|
def _xywh2cs(self, x: float, y: float, w: float, h: float) -> tuple:
|
|
center = np.zeros((2), dtype=np.float32)
|
|
center[0] = x + w * 0.5
|
|
center[1] = y + h * 0.5
|
|
if w > self.aspect_ratio * h:
|
|
h = w * 1.0 / self.aspect_ratio
|
|
elif w < self.aspect_ratio * h:
|
|
w = h * self.aspect_ratio
|
|
scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
|
|
|
|
return center, scale
|
|
|
|
def __getitem__(self, index: int) -> tuple:
|
|
# Load training image
|
|
im_name = self.im_list[index]
|
|
im_path = osp.join(self.root, self.dataset, 'images', im_name + '.jpg')
|
|
im = cv2.imread(im_path, cv2.IMREAD_COLOR)
|
|
h, w, _ = im.shape
|
|
parsing_anno = np.zeros((h, w), dtype=np.long)
|
|
# Get center and scale
|
|
center, s = self._box2cs([0, 0, w - 1, h - 1])
|
|
r = 0
|
|
|
|
if self.dataset not in ['test', 'valid']:
|
|
edge_path = osp.join(self.root, self.dataset, 'edges', im_name + '.png')
|
|
edge = cv2.imread(edge_path, cv2.IMREAD_GRAYSCALE)
|
|
parsing_anno_path = osp.join(self.root, self.dataset, 'labels', im_name + '.png')
|
|
parsing_anno = cv2.imread(parsing_anno_path, cv2.IMREAD_GRAYSCALE)
|
|
|
|
if self.dataset in 'train':
|
|
|
|
sf = self.scale_factor
|
|
rf = self.rotation_factor
|
|
s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
|
|
r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \
|
|
if random.random() <= 0.6 else 0
|
|
|
|
trans = transforms.get_affine_transform(center, s, r, self.crop_size)
|
|
image = cv2.warpAffine(
|
|
im,
|
|
trans,
|
|
(int(self.crop_size[1]), int(self.crop_size[0])),
|
|
flags=cv2.INTER_LINEAR,
|
|
borderMode=cv2.BORDER_CONSTANT,
|
|
borderValue=(0, 0, 0))
|
|
|
|
if self.dataset not in ['test', 'valid']:
|
|
edge = cv2.warpAffine(
|
|
edge,
|
|
trans,
|
|
(int(self.crop_size[1]), int(self.crop_size[0])),
|
|
flags=cv2.INTER_LINEAR,
|
|
borderMode=cv2.BORDER_CONSTANT,
|
|
borderValue=(0, 0, 0))
|
|
|
|
if self.transform:
|
|
image = self.transform(image)
|
|
|
|
meta = {
|
|
'name': im_name,
|
|
'center': center,
|
|
'height': h,
|
|
'width': w,
|
|
'scale': s,
|
|
'rotation': r,
|
|
'origin': image
|
|
}
|
|
|
|
if self.dataset not in 'train':
|
|
return image, meta
|
|
else:
|
|
|
|
label_parsing = cv2.warpAffine(
|
|
parsing_anno,
|
|
trans,
|
|
(int(self.crop_size[1]), int(self.crop_size[0])),
|
|
flags=cv2.INTER_NEAREST,
|
|
borderMode=cv2.BORDER_CONSTANT,
|
|
borderValue=(255))
|
|
|
|
label_parsing = torch.from_numpy(label_parsing)
|
|
|
|
return image, label_parsing, edge, meta
|
|
|