From d92a6bcda0fce7a629563bac912d5fbc05ec2aa5 Mon Sep 17 00:00:00 2001 From: Jia Guo Date: Thu, 25 Jan 2018 20:31:36 +0800 Subject: [PATCH] fix --- src/data/dataset_merge.py | 70 ++++++++++++++++++++++++--------------- src/data/face2rec2.py | 4 ++- 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/src/data/dataset_merge.py b/src/data/dataset_merge.py index 38aea3f..bba45e1 100644 --- a/src/data/dataset_merge.py +++ b/src/data/dataset_merge.py @@ -18,6 +18,15 @@ import numpy as np sys.path.append(os.path.join(os.path.dirname(__file__),'..', 'common')) import face_image +def ch_dev(arg_params, aux_params, ctx): + new_args = dict() + new_auxs = dict() + for k, v in arg_params.items(): + new_args[k] = v.as_in_context(ctx) + for k, v in aux_params.items(): + new_auxs[k] = v.as_in_context(ctx) + return new_args, new_auxs + def get_embedding(args, imgrec, id, image_size, model): s = imgrec.read_idx(id) header, _ = mx.recordio.unpack(s) @@ -49,7 +58,8 @@ def get_embedding(args, imgrec, id, image_size, model): data[ii][:] = data[0][:] label[ii][:] = label[0][:] ii+=1 - db = mx.io.DataBatch(data=(data,), label=(label,)) + #db = mx.io.DataBatch(data=(data,), label=(label,)) + db = mx.io.DataBatch(data=(data,)) model.forward(db, is_train=False) net_out = model.get_outputs() net_out = net_out[0].asnumpy() @@ -82,8 +92,15 @@ def main(args): prefix = vec[0] epoch = int(vec[1]) print('loading',prefix, epoch) - model = mx.mod.Module.load(prefix, epoch, context = ctx) - model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) + sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) + #arg_params, aux_params = ch_dev(arg_params, aux_params, ctx) + all_layers = sym.get_internals() + sym = all_layers['fc1_output'] + #model = mx.mod.Module.load(prefix, epoch, context = ctx) + #model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) + model = mx.mod.Module(symbol=sym, context=ctx, label_names = None) + model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))]) + model.set_params(arg_params, aux_params) rec_list = [] for ds in include_datasets: path_imgrec = os.path.join(ds, 'train.rec') @@ -136,29 +153,30 @@ def main(args): if len(args.exclude)>0: - _path_imgrec = os.path.join(args.exclude, 'train.rec') - _path_imgidx = os.path.join(args.exclude, 'train.idx') - _imgrec = mx.recordio.MXIndexedRecordIO(_path_imgidx, _path_imgrec, 'r') # pylint: disable=redefined-variable-type - _ds_id = len(rec_list) - _id_list = [] - s = _imgrec.read_idx(0) - header, _ = mx.recordio.unpack(s) - assert header.flag>0 - print('header0 label', header.label) - header0 = (int(header.label[0]), int(header.label[1])) - #assert(header.flag==1) - imgidx = range(1, int(header.label[0])) - seq_identity = range(int(header.label[0]), int(header.label[1])) - pp=0 - for identity in seq_identity: - pp+=1 - if pp%10==0: - print('processing ex id', pp) - embedding = get_embedding(args, _imgrec, identity, image_size, model) - #print(embedding.shape) - _id_list.append( (_ds_id, identity, embedding) ) - if test_limit>0 and pp>=test_limit: - break + if os.path.isdir(args.exclude): + _path_imgrec = os.path.join(args.exclude, 'train.rec') + _path_imgidx = os.path.join(args.exclude, 'train.idx') + _imgrec = mx.recordio.MXIndexedRecordIO(_path_imgidx, _path_imgrec, 'r') # pylint: disable=redefined-variable-type + _ds_id = len(rec_list) + _id_list = [] + s = _imgrec.read_idx(0) + header, _ = mx.recordio.unpack(s) + assert header.flag>0 + print('header0 label', header.label) + header0 = (int(header.label[0]), int(header.label[1])) + #assert(header.flag==1) + imgidx = range(1, int(header.label[0])) + seq_identity = range(int(header.label[0]), int(header.label[1])) + pp=0 + for identity in seq_identity: + pp+=1 + if pp%10==0: + print('processing ex id', pp) + embedding = get_embedding(args, _imgrec, identity, image_size, model) + #print(embedding.shape) + _id_list.append( (_ds_id, identity, embedding) ) + if test_limit>0 and pp>=test_limit: + break #X = [] #for id_item in all_id_list: diff --git a/src/data/face2rec2.py b/src/data/face2rec2.py index 6965562..361e398 100644 --- a/src/data/face2rec2.py +++ b/src/data/face2rec2.py @@ -30,6 +30,7 @@ import time import traceback #from builtins import range from easydict import EasyDict as edict +sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'common')) import face_preprocess import face_image @@ -191,7 +192,8 @@ def parse_args(): if __name__ == '__main__': args = parse_args() if args.list: - make_list(args) + pass + #make_list(args) else: if os.path.isdir(args.prefix): working_dir = args.prefix