Files
insightface/recognition/oneflow_face/ofrecord_util.py
2021-01-20 17:25:35 +08:00

151 lines
4.4 KiB
Python

import os
import oneflow as flow
from config import config
def train_dataset_reader(
    data_dir, batch_size, data_part_num, part_name_suffix_length=1
):
    """Build the OFRecord decoding op for the training data.

    Args:
        data_dir: Directory containing the training OFRecord files. It must
            exist on the local filesystem.
        batch_size: Forwarded to the decoder as ``batch_size``.
        data_part_num: Forwarded to the decoder as ``data_part_num``.
        part_name_suffix_length: Forwarded to the decoder as
            ``part_name_suffix_length``.

    Returns:
        The result of ``flow.data.decode_ofrecord`` for the blob configs
        ``(label, image)``, in that order.

    Raises:
        Exception: If ``data_dir`` does not exist.
    """
    if not os.path.exists(data_dir):
        raise Exception("Invalid train dataset dir", data_dir)
    print("Loading train data from {}".format(data_dir))

    # Images are read from the "encoded" field, declared as 112x112x3 float,
    # with "bgr2rgb" and "mirror" image preprocessors applied by the codec.
    image_blob_conf = flow.data.BlobConf(
        "encoded",
        shape=(112, 112, 3),
        dtype=flow.float,
        codec=flow.data.ImageCodec(
            image_preprocessors=[
                flow.data.ImagePreprocessor("bgr2rgb"),
                flow.data.ImagePreprocessor("mirror"),
            ]
        ),
        preprocessors=[
            # Per-channel normalization with mean 127.5 and std 128.
            flow.data.NormByChannelPreprocessor(
                mean_values=(127.5, 127.5, 127.5), std_values=(128, 128, 128)
            ),
        ],
    )
    # Labels are read from the "label" field as scalar int32 values.
    label_blob_conf = flow.data.BlobConf(
        "label", shape=(), dtype=flow.int32, codec=flow.data.RawCodec()
    )

    # The file-name prefix and the shuffle flag still come from the global
    # config; the suffix length honours the argument passed by the caller
    # instead of silently overriding it with the config value.
    return flow.data.decode_ofrecord(
        data_dir,
        (label_blob_conf, image_blob_conf),
        batch_size=batch_size,
        data_part_num=data_part_num,
        part_name_prefix=config.part_name_prefix,
        part_name_suffix_length=part_name_suffix_length,
        shuffle=config.shuffle,
        buffer_size=16384,
    )
def validation_dataset_reader(val_dataset_dir, val_batch_size=1, val_data_part_num=1):
    """Build the reading and preprocessing ops for a validation set.

    Args:
        val_dataset_dir: Directory with the validation OFRecord files. It
            must exist on the local filesystem.
        val_batch_size: Forwarded to the OFRecord reader as ``batch_size``.
        val_data_part_num: Forwarded to the reader as ``data_part_num``.

    Returns:
        A tuple ``(issame, images)``: the decoded "issame" field and the
        resized, normalized and transposed image blob.

    Raises:
        Exception: If ``val_dataset_dir`` does not exist.
    """
    # Dataset shapes noted by the original author:
    #   lfw:      (12000L, 3L, 112L, 112L)
    #   cfp_fp:   (14000L, 3L, 112L, 112L)
    #   agedb_30: (12000L, 3L, 112L, 112L)
    if not os.path.exists(val_dataset_dir):
        raise Exception("Invalid validation dataset dir", val_dataset_dir)
    print("Loading validation data from {}".format(val_dataset_dir))

    color_space = "RGB"
    records = flow.data.ofrecord_reader(
        val_dataset_dir,
        batch_size=val_batch_size,
        data_part_num=val_data_part_num,
        part_name_suffix_length=1,
        shuffle_after_epoch=False,
    )

    decoded = flow.data.OFRecordImageDecoder(
        records, "encoded", color_space=color_space
    )
    pair_flags = flow.data.OFRecordRawDecoder(
        records, "issame", shape=(), dtype=flow.int32
    )

    # Resize yields three values; only the resized image is used below.
    resized, _scale, _new_size = flow.image.Resize(
        decoded, target_size=(112, 112), channels=3
    )
    normalized = flow.image.CropMirrorNormalize(
        resized,
        color_space=color_space,
        crop_h=0,
        crop_w=0,
        crop_pos_y=0.5,
        crop_pos_x=0.5,
        mean=[127.5, 127.5, 127.5],
        std=[128.0, 128.0, 128.0],
        output_dtype=flow.float,
    )
    # NOTE(review): perm=[0, 2, 3, 1] presumably turns a channels-first
    # batch into channels-last - confirm against CropMirrorNormalize output.
    images = flow.transpose(normalized, name="transpose_val", perm=[0, 2, 3, 1])
    return pair_flags, images
def load_synthetic(config):
    """Create randomly generated label and image blobs.

    Args:
        config: Object exposing ``train_batch_size``. This parameter shadows
            the module-level ``config`` import inside the function.

    Returns:
        A tuple ``(label, image)`` of blobs produced by
        ``flow.data.decode_random``.
    """
    n_samples = config.train_batch_size
    side = 112

    # Labels are scalar int32 values initialized with the zeros initializer.
    synthetic_label = flow.data.decode_random(
        shape=(),
        dtype=flow.int32,
        batch_size=n_samples,
        initializer=flow.zeros_initializer(flow.int32),
    )
    # Images are declared as side x side x 3 float blobs.
    synthetic_image = flow.data.decode_random(
        shape=(side, side, 3),
        dtype=flow.float,
        batch_size=n_samples,
    )
    return synthetic_label, synthetic_image
def load_train_dataset(args):
    """Create the training data reader from ``args`` and the global config.

    Args:
        args: Object exposing ``train_batch_size``.

    Returns:
        A tuple ``(labels, images)`` as produced by ``train_dataset_reader``.
    """
    # Directory, part count and suffix length come from the module-level
    # config; only the batch size is taken from ``args``.
    reader_args = (
        config.dataset_dir,
        args.train_batch_size,
        config.train_data_part_num,
        config.part_name_suffix_length,
    )
    print("train batch size in load train dataset: ", reader_args[1])
    label_blob, image_blob = train_dataset_reader(*reader_args)
    return label_blob, image_blob
def load_lfw_dataset(args):
    """Create the validation reader for the LFW dataset.

    Args:
        args: Object exposing ``lfw_dataset_dir``,
            ``val_batch_size_per_device`` and ``val_data_part_num``.

    Returns:
        The ``(issame, images)`` tuple from ``validation_dataset_reader``.
    """
    return validation_dataset_reader(
        val_dataset_dir=args.lfw_dataset_dir,
        val_batch_size=args.val_batch_size_per_device,
        val_data_part_num=args.val_data_part_num,
    )
def load_cfp_fp_dataset(args):
    """Create the validation reader for the CFP-FP dataset.

    Args:
        args: Object exposing ``cfp_fp_dataset_dir``,
            ``val_batch_size_per_device`` and ``val_data_part_num``.

    Returns:
        The ``(issame, images)`` tuple from ``validation_dataset_reader``.
    """
    return validation_dataset_reader(
        val_dataset_dir=args.cfp_fp_dataset_dir,
        val_batch_size=args.val_batch_size_per_device,
        val_data_part_num=args.val_data_part_num,
    )
def load_agedb_30_dataset(args):
    """Create the validation reader for the AgeDB-30 dataset.

    Args:
        args: Object exposing ``agedb_30_dataset_dir``,
            ``val_batch_size_per_device`` and ``val_data_part_num``.

    Returns:
        The ``(issame, images)`` tuple from ``validation_dataset_reader``.
    """
    return validation_dataset_reader(
        val_dataset_dir=args.agedb_30_dataset_dir,
        val_batch_size=args.val_batch_size_per_device,
        val_data_part_num=args.val_data_part_num,
    )