This commit is contained in:
Jia Guo
2018-01-25 20:31:36 +08:00
parent f916173825
commit d92a6bcda0
2 changed files with 47 additions and 27 deletions

View File

@@ -18,6 +18,15 @@ import numpy as np
sys.path.append(os.path.join(os.path.dirname(__file__),'..', 'common'))
import face_image
def ch_dev(arg_params, aux_params, ctx):
new_args = dict()
new_auxs = dict()
for k, v in arg_params.items():
new_args[k] = v.as_in_context(ctx)
for k, v in aux_params.items():
new_auxs[k] = v.as_in_context(ctx)
return new_args, new_auxs
def get_embedding(args, imgrec, id, image_size, model):
s = imgrec.read_idx(id)
header, _ = mx.recordio.unpack(s)
@@ -49,7 +58,8 @@ def get_embedding(args, imgrec, id, image_size, model):
data[ii][:] = data[0][:]
label[ii][:] = label[0][:]
ii+=1
db = mx.io.DataBatch(data=(data,), label=(label,))
#db = mx.io.DataBatch(data=(data,), label=(label,))
db = mx.io.DataBatch(data=(data,))
model.forward(db, is_train=False)
net_out = model.get_outputs()
net_out = net_out[0].asnumpy()
@@ -82,8 +92,15 @@ def main(args):
prefix = vec[0]
epoch = int(vec[1])
print('loading',prefix, epoch)
model = mx.mod.Module.load(prefix, epoch, context = ctx)
model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
#arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
all_layers = sym.get_internals()
sym = all_layers['fc1_output']
#model = mx.mod.Module.load(prefix, epoch, context = ctx)
#model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)
model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))])
model.set_params(arg_params, aux_params)
rec_list = []
for ds in include_datasets:
path_imgrec = os.path.join(ds, 'train.rec')
@@ -136,29 +153,30 @@ def main(args):
if len(args.exclude)>0:
_path_imgrec = os.path.join(args.exclude, 'train.rec')
_path_imgidx = os.path.join(args.exclude, 'train.idx')
_imgrec = mx.recordio.MXIndexedRecordIO(_path_imgidx, _path_imgrec, 'r') # pylint: disable=redefined-variable-type
_ds_id = len(rec_list)
_id_list = []
s = _imgrec.read_idx(0)
header, _ = mx.recordio.unpack(s)
assert header.flag>0
print('header0 label', header.label)
header0 = (int(header.label[0]), int(header.label[1]))
#assert(header.flag==1)
imgidx = range(1, int(header.label[0]))
seq_identity = range(int(header.label[0]), int(header.label[1]))
pp=0
for identity in seq_identity:
pp+=1
if pp%10==0:
print('processing ex id', pp)
embedding = get_embedding(args, _imgrec, identity, image_size, model)
#print(embedding.shape)
_id_list.append( (_ds_id, identity, embedding) )
if test_limit>0 and pp>=test_limit:
break
if os.path.isdir(args.exclude):
_path_imgrec = os.path.join(args.exclude, 'train.rec')
_path_imgidx = os.path.join(args.exclude, 'train.idx')
_imgrec = mx.recordio.MXIndexedRecordIO(_path_imgidx, _path_imgrec, 'r') # pylint: disable=redefined-variable-type
_ds_id = len(rec_list)
_id_list = []
s = _imgrec.read_idx(0)
header, _ = mx.recordio.unpack(s)
assert header.flag>0
print('header0 label', header.label)
header0 = (int(header.label[0]), int(header.label[1]))
#assert(header.flag==1)
imgidx = range(1, int(header.label[0]))
seq_identity = range(int(header.label[0]), int(header.label[1]))
pp=0
for identity in seq_identity:
pp+=1
if pp%10==0:
print('processing ex id', pp)
embedding = get_embedding(args, _imgrec, identity, image_size, model)
#print(embedding.shape)
_id_list.append( (_ds_id, identity, embedding) )
if test_limit>0 and pp>=test_limit:
break
#X = []
#for id_item in all_id_list:

View File

@@ -30,6 +30,7 @@ import time
import traceback
#from builtins import range
from easydict import EasyDict as edict
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'common'))
import face_preprocess
import face_image
@@ -191,7 +192,8 @@ def parse_args():
if __name__ == '__main__':
args = parse_args()
if args.list:
make_list(args)
pass
#make_list(args)
else:
if os.path.isdir(args.prefix):
working_dir = args.prefix