mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-19 15:41:33 +00:00
fix
This commit is contained in:
@@ -18,6 +18,15 @@ import numpy as np
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__),'..', 'common'))
|
||||
import face_image
|
||||
|
||||
def ch_dev(arg_params, aux_params, ctx):
|
||||
new_args = dict()
|
||||
new_auxs = dict()
|
||||
for k, v in arg_params.items():
|
||||
new_args[k] = v.as_in_context(ctx)
|
||||
for k, v in aux_params.items():
|
||||
new_auxs[k] = v.as_in_context(ctx)
|
||||
return new_args, new_auxs
|
||||
|
||||
def get_embedding(args, imgrec, id, image_size, model):
|
||||
s = imgrec.read_idx(id)
|
||||
header, _ = mx.recordio.unpack(s)
|
||||
@@ -49,7 +58,8 @@ def get_embedding(args, imgrec, id, image_size, model):
|
||||
data[ii][:] = data[0][:]
|
||||
label[ii][:] = label[0][:]
|
||||
ii+=1
|
||||
db = mx.io.DataBatch(data=(data,), label=(label,))
|
||||
#db = mx.io.DataBatch(data=(data,), label=(label,))
|
||||
db = mx.io.DataBatch(data=(data,))
|
||||
model.forward(db, is_train=False)
|
||||
net_out = model.get_outputs()
|
||||
net_out = net_out[0].asnumpy()
|
||||
@@ -82,8 +92,15 @@ def main(args):
|
||||
prefix = vec[0]
|
||||
epoch = int(vec[1])
|
||||
print('loading',prefix, epoch)
|
||||
model = mx.mod.Module.load(prefix, epoch, context = ctx)
|
||||
model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
|
||||
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
|
||||
#arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
|
||||
all_layers = sym.get_internals()
|
||||
sym = all_layers['fc1_output']
|
||||
#model = mx.mod.Module.load(prefix, epoch, context = ctx)
|
||||
#model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
|
||||
model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)
|
||||
model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))])
|
||||
model.set_params(arg_params, aux_params)
|
||||
rec_list = []
|
||||
for ds in include_datasets:
|
||||
path_imgrec = os.path.join(ds, 'train.rec')
|
||||
@@ -136,29 +153,30 @@ def main(args):
|
||||
|
||||
|
||||
if len(args.exclude)>0:
|
||||
_path_imgrec = os.path.join(args.exclude, 'train.rec')
|
||||
_path_imgidx = os.path.join(args.exclude, 'train.idx')
|
||||
_imgrec = mx.recordio.MXIndexedRecordIO(_path_imgidx, _path_imgrec, 'r') # pylint: disable=redefined-variable-type
|
||||
_ds_id = len(rec_list)
|
||||
_id_list = []
|
||||
s = _imgrec.read_idx(0)
|
||||
header, _ = mx.recordio.unpack(s)
|
||||
assert header.flag>0
|
||||
print('header0 label', header.label)
|
||||
header0 = (int(header.label[0]), int(header.label[1]))
|
||||
#assert(header.flag==1)
|
||||
imgidx = range(1, int(header.label[0]))
|
||||
seq_identity = range(int(header.label[0]), int(header.label[1]))
|
||||
pp=0
|
||||
for identity in seq_identity:
|
||||
pp+=1
|
||||
if pp%10==0:
|
||||
print('processing ex id', pp)
|
||||
embedding = get_embedding(args, _imgrec, identity, image_size, model)
|
||||
#print(embedding.shape)
|
||||
_id_list.append( (_ds_id, identity, embedding) )
|
||||
if test_limit>0 and pp>=test_limit:
|
||||
break
|
||||
if os.path.isdir(args.exclude):
|
||||
_path_imgrec = os.path.join(args.exclude, 'train.rec')
|
||||
_path_imgidx = os.path.join(args.exclude, 'train.idx')
|
||||
_imgrec = mx.recordio.MXIndexedRecordIO(_path_imgidx, _path_imgrec, 'r') # pylint: disable=redefined-variable-type
|
||||
_ds_id = len(rec_list)
|
||||
_id_list = []
|
||||
s = _imgrec.read_idx(0)
|
||||
header, _ = mx.recordio.unpack(s)
|
||||
assert header.flag>0
|
||||
print('header0 label', header.label)
|
||||
header0 = (int(header.label[0]), int(header.label[1]))
|
||||
#assert(header.flag==1)
|
||||
imgidx = range(1, int(header.label[0]))
|
||||
seq_identity = range(int(header.label[0]), int(header.label[1]))
|
||||
pp=0
|
||||
for identity in seq_identity:
|
||||
pp+=1
|
||||
if pp%10==0:
|
||||
print('processing ex id', pp)
|
||||
embedding = get_embedding(args, _imgrec, identity, image_size, model)
|
||||
#print(embedding.shape)
|
||||
_id_list.append( (_ds_id, identity, embedding) )
|
||||
if test_limit>0 and pp>=test_limit:
|
||||
break
|
||||
|
||||
#X = []
|
||||
#for id_item in all_id_list:
|
||||
|
||||
@@ -30,6 +30,7 @@ import time
|
||||
import traceback
|
||||
#from builtins import range
|
||||
from easydict import EasyDict as edict
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'common'))
|
||||
import face_preprocess
|
||||
import face_image
|
||||
|
||||
@@ -191,7 +192,8 @@ def parse_args():
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
if args.list:
|
||||
make_list(args)
|
||||
pass
|
||||
#make_list(args)
|
||||
else:
|
||||
if os.path.isdir(args.prefix):
|
||||
working_dir = args.prefix
|
||||
|
||||
Reference in New Issue
Block a user