This commit is contained in:
nttstar
2018-04-27 14:53:58 +08:00
parent bd45b710c1
commit f7c83cac4f
3 changed files with 27 additions and 12 deletions

View File

@@ -163,6 +163,8 @@ class FaceImageIter(io.DataIter):
try:
while i < batch_size:
label, s, bbox, landmark = self.next_sample()
#if label[1]>=0.0 or label[2]>=0.0:
# print(label[0:10])
_data = self.imdecode(s)
if self.rand_mirror:
_rd = random.randint(0,1)

View File

@@ -24,6 +24,10 @@ def main(args):
stat = [0,0]
for ds in ['ms1m', 'megaage', 'imdb']:
for n in ['train', 'val']:
if ds=='ms1m' or ds=='imdb':
continue
if n=='val' and ds!='megaage':
continue
repeat = 1
if n=='train' and ds=='megaage':
repeat = 1
@@ -34,7 +38,7 @@ def main(args):
if n=='val':
writer = val_writer
widx = val_widx
path = os.path.join(ds, '%s.rec'%n)
path = os.path.join(args.input, ds, '%s.rec'%n)
if not os.path.exists(path):
continue
imgrec = mx.recordio.MXIndexedRecordIO(path[:-3]+'idx', path, 'r') # pylint: disable=redefined-variable-type
@@ -93,6 +97,7 @@ def main(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='do dataset merge')
# general
parser.add_argument('--input', default='', type=str, help='')
parser.add_argument('--output', default='', type=str, help='')
args = parser.parse_args()
main(args)

View File

@@ -50,17 +50,21 @@ USE_AGE = True
class AccMetric(mx.metric.EvalMetric):
def __init__(self, pred_idx = 1, name='acc'):
def __init__(self, pred_idx = 1, label_idx = 0, name='acc'):
self.axis = 1
self.pred_idx = pred_idx
self.label_idx = label_idx
super(AccMetric, self).__init__(
'name', axis=self.axis,
name, axis=self.axis,
output_names=None, label_names=None)
self.name = name
self.losses = []
self.count = 0
def update(self, labels, preds):
self.count+=1
#print('label num', len(labels))
preds = [preds[self.pred_idx]] #use softmax output
for label, pred_label in zip(labels, preds):
if pred_label.shape != label.shape:
@@ -68,8 +72,9 @@ class AccMetric(mx.metric.EvalMetric):
pred_label = pred_label.asnumpy().astype('int32').flatten()
label = label.asnumpy()
if label.ndim==2:
label = label[:,0]
label = label[:,self.label_idx]
label = label.astype('int32').flatten()
print(self.name, label)
assert label.shape==pred_label.shape
self.sum_metric += (pred_label.flat == label.flat).sum()
self.num_inst += len(pred_label.flat)
@@ -331,14 +336,17 @@ def train_net(args):
data_dir = data_dir_list[0]
path_imgrec = None
path_imglist = None
prop = face_image.load_property(data_dir)
args.num_classes = prop.num_classes
image_size = prop.image_size
args.num_classes = 0
image_size = (112,112)
if os.path.exists(os.path.join(data_dir, 'property')):
prop = face_image.load_property(data_dir)
args.num_classes = prop.num_classes
image_size = prop.image_size
assert(args.num_classes>0)
print('num_classes', args.num_classes)
args.image_h = image_size[0]
args.image_w = image_size[1]
print('image_size', image_size)
assert(args.num_classes>0)
print('num_classes', args.num_classes)
path_imgrec = os.path.join(data_dir, "train.rec")
print('Called with argument:', args)
@@ -392,13 +400,13 @@ def train_net(args):
eval_metrics = []
if USE_FR:
_metric = AccMetric(pred_idx=1)
_metric = AccMetric(pred_idx=1, label_idx=0)
eval_metrics.append(_metric)
if USE_GENDER:
_metric = AccMetric(pred_idx=2, name='gender')
_metric = AccMetric(pred_idx=2, label_idx=1, name='gender')
eval_metrics.append(_metric)
elif USE_GENDER:
_metric = AccMetric(pred_idx=1, name='gender')
_metric = AccMetric(pred_idx=1, label_idx=1, name='gender')
eval_metrics.append(_metric)
if USE_AGE:
_metric = MAEMetric()