import os

import oneflow as flow

from config import config


def train_dataset_reader(
    data_dir, batch_size, data_part_num, part_name_suffix_length=1
):
    """Build the training pipeline from an OFRecord dataset directory.

    Returns (labels, images): int32 identity labels and NHWC float images,
    BGR-to-RGB converted, mirror-augmented, and normalized per channel.
    """
    if os.path.exists(data_dir):
        print("Loading train data from {}".format(data_dir))
    else:
        raise ValueError("Invalid train dataset dir: {}".format(data_dir))

    image_blob_conf = flow.data.BlobConf(
        "encoded",
        shape=(112, 112, 3),
        dtype=flow.float,
        codec=flow.data.ImageCodec(
            image_preprocessors=[
                flow.data.ImagePreprocessor("bgr2rgb"),
                flow.data.ImagePreprocessor("mirror"),  # random horizontal flip
            ]
        ),
        preprocessors=[
            flow.data.NormByChannelPreprocessor(
                mean_values=(127.5, 127.5, 127.5), std_values=(128, 128, 128)
            ),
        ],
    )

    label_blob_conf = flow.data.BlobConf(
        "label", shape=(), dtype=flow.int32, codec=flow.data.RawCodec()
    )

    return flow.data.decode_ofrecord(
        data_dir,
        (label_blob_conf, image_blob_conf),
        batch_size=batch_size,
        data_part_num=data_part_num,
        part_name_prefix=config.part_name_prefix,
        # Use the function parameter here instead of re-reading config, so the
        # `part_name_suffix_length` argument is actually honored. The only call
        # site passes config.part_name_suffix_length, so behavior is unchanged.
        part_name_suffix_length=part_name_suffix_length,
        shuffle=config.shuffle,
        buffer_size=16384,
    )

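# Usage sketch (hedged): the reader above is meant to run inside an OneFlow
# lazy-mode job function, reading an OFRecord dataset laid out as part files
# (e.g. part-0 ... part-15) under `data_dir`. The path below is a placeholder,
# not one this repo ships:
#
#     labels, images = train_dataset_reader(
#         "/datasets/faces_emore/ofrecord", batch_size=64, data_part_num=16
#     )
#     # images: (64, 112, 112, 3) float, roughly in [-1, 1] after normalization
#     # labels: (64,) int32 identity ids
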
def validation_dataset_reader(val_dataset_dir, val_batch_size=1, val_data_part_num=1):
    """Build the validation pipeline for the verification benchmarks.

    Reference shapes of the verification sets:
      lfw:      (12000, 3, 112, 112)
      cfp_fp:   (14000, 3, 112, 112)
      agedb_30: (12000, 3, 112, 112)
    """
    if os.path.exists(val_dataset_dir):
        print("Loading validation data from {}".format(val_dataset_dir))
    else:
        raise ValueError("Invalid validation dataset dir: {}".format(val_dataset_dir))

    color_space = "RGB"
    ofrecord = flow.data.ofrecord_reader(
        val_dataset_dir,
        batch_size=val_batch_size,
        data_part_num=val_data_part_num,
        part_name_suffix_length=1,
        shuffle_after_epoch=False,  # keep evaluation order deterministic
    )
    image = flow.data.OFRecordImageDecoder(ofrecord, "encoded", color_space=color_space)
    issame = flow.data.OFRecordRawDecoder(
        ofrecord, "issame", shape=(), dtype=flow.int32
    )

    rsz, scale, new_size = flow.image.Resize(image, target_size=(112, 112), channels=3)
    normal = flow.image.CropMirrorNormalize(
        rsz,
        color_space=color_space,
        crop_h=0,  # crop size 0 means no cropping; only normalization applies
        crop_w=0,
        crop_pos_y=0.5,
        crop_pos_x=0.5,
        mean=[127.5, 127.5, 127.5],
        std=[128.0, 128.0, 128.0],
        output_dtype=flow.float,
    )

    # CropMirrorNormalize emits NCHW; transpose back to NHWC to match the
    # training reader's layout.
    normal = flow.transpose(normal, name="transpose_val", perm=[0, 2, 3, 1])

    return issame, normal

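# Usage sketch (hedged; the directory is a placeholder): one call per
# verification set, returning per-record `issame` flags for the verification
# metric alongside NHWC image batches:
#
#     issame, images = validation_dataset_reader(
#         "/datasets/lfw/ofrecord", val_batch_size=100
#     )
#     # images: (100, 112, 112, 3) float; issame: (100,) int32
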
def load_synthetic(config):  # note: the parameter shadows the module-level import
    """Generate random labels/images of the training shape, for benchmarking."""
    batch_size = config.train_batch_size
    image_size = 112
    label = flow.data.decode_random(
        shape=(),
        dtype=flow.int32,
        batch_size=batch_size,
        initializer=flow.zeros_initializer(flow.int32),
    )

    image = flow.data.decode_random(
        shape=(image_size, image_size, 3), dtype=flow.float, batch_size=batch_size,
    )
    return label, image

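# load_synthetic produces random label/image tensors without touching the
# filesystem, which makes it a convenient way to benchmark training throughput;
# swap it in wherever load_train_dataset would otherwise be called.
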
def load_train_dataset(args):
    data_dir = config.dataset_dir
    batch_size = args.train_batch_size
    data_part_num = config.train_data_part_num
    part_name_suffix_length = config.part_name_suffix_length
    print("train batch size in load train dataset: ", batch_size)
    labels, images = train_dataset_reader(
        data_dir, batch_size, data_part_num, part_name_suffix_length
    )
    return labels, images

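# Note the split above: the per-run batch size comes from the CLI args, while
# the dataset location and part-file layout stay in config.py.
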
def load_lfw_dataset(args):
    data_dir = args.lfw_dataset_dir
    batch_size = args.val_batch_size_per_device
    data_part_num = args.val_data_part_num

    issame, images = validation_dataset_reader(
        val_dataset_dir=data_dir,
        val_batch_size=batch_size,
        val_data_part_num=data_part_num,
    )
    return issame, images


def load_cfp_fp_dataset(args):
    data_dir = args.cfp_fp_dataset_dir
    batch_size = args.val_batch_size_per_device
    data_part_num = args.val_data_part_num

    issame, images = validation_dataset_reader(
        val_dataset_dir=data_dir,
        val_batch_size=batch_size,
        val_data_part_num=data_part_num,
    )
    return issame, images


def load_agedb_30_dataset(args):
    data_dir = args.agedb_30_dataset_dir
    batch_size = args.val_batch_size_per_device
    data_part_num = args.val_data_part_num

    issame, images = validation_dataset_reader(
        val_dataset_dir=data_dir,
        val_batch_size=batch_size,
        val_data_part_num=data_part_num,
    )
    return issame, images

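# End-to-end sketch (hedged): in OneFlow 0.x lazy mode these loaders run inside
# a @flow.global_function job; the decorator's exact signature varies across
# releases, and `args` is assumed to come from this repo's argument parser.
#
#     @flow.global_function(type="predict")
#     def lfw_job():
#         return load_lfw_dataset(args)
#
#     issame, images = lfw_job().get()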