mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-17 14:26:08 +00:00
tiny
This commit is contained in:
@@ -163,6 +163,8 @@ class FaceImageIter(io.DataIter):
|
||||
try:
|
||||
while i < batch_size:
|
||||
label, s, bbox, landmark = self.next_sample()
|
||||
#if label[1]>=0.0 or label[2]>=0.0:
|
||||
# print(label[0:10])
|
||||
_data = self.imdecode(s)
|
||||
if self.rand_mirror:
|
||||
_rd = random.randint(0,1)
|
||||
|
||||
@@ -24,6 +24,10 @@ def main(args):
|
||||
stat = [0,0]
|
||||
for ds in ['ms1m', 'megaage', 'imdb']:
|
||||
for n in ['train', 'val']:
|
||||
if ds=='ms1m' or ds=='imdb':
|
||||
continue
|
||||
if n=='val' and ds!='megaage':
|
||||
continue
|
||||
repeat = 1
|
||||
if n=='train' and ds=='megaage':
|
||||
repeat = 1
|
||||
@@ -34,7 +38,7 @@ def main(args):
|
||||
if n=='val':
|
||||
writer = val_writer
|
||||
widx = val_widx
|
||||
path = os.path.join(ds, '%s.rec'%n)
|
||||
path = os.path.join(args.input, ds, '%s.rec'%n)
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
imgrec = mx.recordio.MXIndexedRecordIO(path[:-3]+'idx', path, 'r') # pylint: disable=redefined-variable-type
|
||||
@@ -93,6 +97,7 @@ def main(args):
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='do dataset merge')
|
||||
# general
|
||||
parser.add_argument('--input', default='', type=str, help='')
|
||||
parser.add_argument('--output', default='', type=str, help='')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -50,17 +50,21 @@ USE_AGE = True
|
||||
|
||||
|
||||
class AccMetric(mx.metric.EvalMetric):
|
||||
def __init__(self, pred_idx = 1, name='acc'):
|
||||
def __init__(self, pred_idx = 1, label_idx = 0, name='acc'):
|
||||
self.axis = 1
|
||||
self.pred_idx = pred_idx
|
||||
self.label_idx = label_idx
|
||||
super(AccMetric, self).__init__(
|
||||
'name', axis=self.axis,
|
||||
name, axis=self.axis,
|
||||
output_names=None, label_names=None)
|
||||
self.name = name
|
||||
self.losses = []
|
||||
self.count = 0
|
||||
|
||||
def update(self, labels, preds):
|
||||
self.count+=1
|
||||
|
||||
#print('label num', len(labels))
|
||||
preds = [preds[self.pred_idx]] #use softmax output
|
||||
for label, pred_label in zip(labels, preds):
|
||||
if pred_label.shape != label.shape:
|
||||
@@ -68,8 +72,9 @@ class AccMetric(mx.metric.EvalMetric):
|
||||
pred_label = pred_label.asnumpy().astype('int32').flatten()
|
||||
label = label.asnumpy()
|
||||
if label.ndim==2:
|
||||
label = label[:,0]
|
||||
label = label[:,self.label_idx]
|
||||
label = label.astype('int32').flatten()
|
||||
print(self.name, label)
|
||||
assert label.shape==pred_label.shape
|
||||
self.sum_metric += (pred_label.flat == label.flat).sum()
|
||||
self.num_inst += len(pred_label.flat)
|
||||
@@ -331,14 +336,17 @@ def train_net(args):
|
||||
data_dir = data_dir_list[0]
|
||||
path_imgrec = None
|
||||
path_imglist = None
|
||||
prop = face_image.load_property(data_dir)
|
||||
args.num_classes = prop.num_classes
|
||||
image_size = prop.image_size
|
||||
args.num_classes = 0
|
||||
image_size = (112,112)
|
||||
if os.path.exists(os.path.join(data_dir, 'property')):
|
||||
prop = face_image.load_property(data_dir)
|
||||
args.num_classes = prop.num_classes
|
||||
image_size = prop.image_size
|
||||
assert(args.num_classes>0)
|
||||
print('num_classes', args.num_classes)
|
||||
args.image_h = image_size[0]
|
||||
args.image_w = image_size[1]
|
||||
print('image_size', image_size)
|
||||
assert(args.num_classes>0)
|
||||
print('num_classes', args.num_classes)
|
||||
path_imgrec = os.path.join(data_dir, "train.rec")
|
||||
|
||||
print('Called with argument:', args)
|
||||
@@ -392,13 +400,13 @@ def train_net(args):
|
||||
|
||||
eval_metrics = []
|
||||
if USE_FR:
|
||||
_metric = AccMetric(pred_idx=1)
|
||||
_metric = AccMetric(pred_idx=1, label_idx=0)
|
||||
eval_metrics.append(_metric)
|
||||
if USE_GENDER:
|
||||
_metric = AccMetric(pred_idx=2, name='gender')
|
||||
_metric = AccMetric(pred_idx=2, label_idx=1, name='gender')
|
||||
eval_metrics.append(_metric)
|
||||
elif USE_GENDER:
|
||||
_metric = AccMetric(pred_idx=1, name='gender')
|
||||
_metric = AccMetric(pred_idx=1, label_idx=1, name='gender')
|
||||
eval_metrics.append(_metric)
|
||||
if USE_AGE:
|
||||
_metric = MAEMetric()
|
||||
|
||||
Reference in New Issue
Block a user