from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import sys import math import random import logging import pickle import numpy as np from data import FaceImageIter from data import FaceImageIterList import mxnet as mx from mxnet import ndarray as nd import argparse import mxnet.optimizer as optimizer sys.path.append(os.path.join(os.path.dirname(__file__), 'common')) import face_image sys.path.append(os.path.join(os.path.dirname(__file__), 'eval')) sys.path.append(os.path.join(os.path.dirname(__file__), 'symbols')) import fresnet import finception_resnet_v2 import fmobilenet import fxception import fdensenet import fdpn import fnasnet #import lfw import verification import sklearn sys.path.append(os.path.join(os.path.dirname(__file__), 'losses')) import center_loss logger = logging.getLogger() logger.setLevel(logging.INFO) class AccMetric(mx.metric.EvalMetric): def __init__(self): self.axis = 1 super(AccMetric, self).__init__( 'acc', axis=self.axis, output_names=None, label_names=None) self.losses = [] def update(self, labels, preds): #loss = preds[2].asnumpy()[0] #if len(self.losses)==20: # print('ce loss', sum(self.losses)/len(self.losses)) # self.losses = [] #self.losses.append(loss) preds = [preds[1]] #use softmax output for label, pred_label in zip(labels, preds): #print(pred_label) #print(label.shape, pred_label.shape) if pred_label.shape != label.shape: pred_label = mx.ndarray.argmax(pred_label, axis=self.axis) pred_label = pred_label.asnumpy().astype('int32').flatten() label = label.asnumpy().astype('int32').flatten() #print(label) #print('label',label) #print('pred_label', pred_label) assert label.shape==pred_label.shape self.sum_metric += (pred_label.flat == label.flat).sum() self.num_inst += len(pred_label.flat) class LossValueMetric(mx.metric.EvalMetric): def __init__(self): self.axis = 1 super(LossValueMetric, self).__init__( 'lossvalue', axis=self.axis, output_names=None, label_names=None) self.losses = [] def update(self, labels, preds): loss = preds[-1].asnumpy()[0] self.sum_metric += loss self.num_inst += 1.0 gt_label = preds[-2].asnumpy() #print(gt_label) def parse_args(): parser = argparse.ArgumentParser(description='Train face network') # general parser.add_argument('--data-dir', default='', help='') parser.add_argument('--prefix', default='../model/model', help='directory to save model.') parser.add_argument('--pretrained', default='', help='') parser.add_argument('--retrain', action='store_true', default=False, help='true means continue training.') parser.add_argument('--ckpt', type=int, default=1, help='') parser.add_argument('--network', default='s20', help='') parser.add_argument('--version-se', type=int, default=0, help='') parser.add_argument('--version-input', type=int, default=1, help='') parser.add_argument('--version-output', type=str, default='E', help='') parser.add_argument('--version-unit', type=int, default=3, help='') parser.add_argument('--end-epoch', type=int, default=100000, help='training epoch size.') parser.add_argument('--lr', type=float, default=0.1, help='') parser.add_argument('--wd', type=float, default=0.0005, help='') parser.add_argument('--mom', type=float, default=0.9, help='') parser.add_argument('--embedding-dim', type=int, default=512, help='') parser.add_argument('--per-batch-size', type=int, default=0, help='') parser.add_argument('--margin-m', type=float, default=0.35, help='') parser.add_argument('--margin-s', type=float, default=64.0, help='') parser.add_argument('--margin', type=int, default=4, help='') parser.add_argument('--beta', type=float, default=1000., help='') parser.add_argument('--beta-min', type=float, default=5., help='') parser.add_argument('--beta-freeze', type=int, default=0, help='') parser.add_argument('--gamma', type=float, default=0.12, help='') parser.add_argument('--power', type=float, default=1.0, help='') parser.add_argument('--scale', type=float, default=0.9993, help='') parser.add_argument('--center-alpha', type=float, default=0.5, help='') parser.add_argument('--center-scale', type=float, default=0.003, help='') parser.add_argument('--images-per-identity', type=int, default=0, help='') parser.add_argument('--triplet-bag-size', type=int, default=3600, help='') parser.add_argument('--triplet-alpha', type=float, default=0.3, help='') parser.add_argument('--triplet-max-ap', type=float, default=0.0, help='') parser.add_argument('--verbose', type=int, default=2000, help='') parser.add_argument('--loss-type', type=int, default=1, help='') parser.add_argument('--incay', type=float, default=0.0, help='feature incay') parser.add_argument('--use-deformable', type=int, default=0, help='') parser.add_argument('--patch', type=str, default='0_0_96_112_0', help='') parser.add_argument('--lr-steps', type=str, default='', help='') parser.add_argument('--target', type=str, default='lfw,cfp_ff,cfp_fp,agedb_30', help='') args = parser.parse_args() return args def get_symbol(args, arg_params, aux_params): data_shape = (args.image_channel,args.image_h,args.image_w) image_shape = ",".join([str(x) for x in data_shape]) if args.network[0]=='d': embedding = fdensenet.get_symbol(512, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0]=='m': print('init mobilenet', args.num_layers) embedding = fmobilenet.get_symbol(512, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0]=='i': print('init inception-resnet-v2', args.num_layers) embedding = finception_resnet_v2.get_symbol(512, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0]=='x': print('init xception', args.num_layers) embedding = fxception.get_symbol(512, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0]=='p': print('init dpn', args.num_layers) embedding = fdpn.get_symbol(512, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0]=='n': print('init nasnet', args.num_layers) embedding = fnasnet.get_symbol(512) else: print('init resnet', args.num_layers) embedding = fresnet.get_symbol(512, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) gt_label = mx.symbol.Variable('softmax_label') assert args.loss_type>=0 extra_loss = None if args.loss_type==0: #softmax _weight = mx.symbol.Variable('fc7_weight') _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7') elif args.loss_type==1: #sphere _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight = _weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=1000, name='fc7') elif args.loss_type==8: #centerloss, TODO _weight = mx.symbol.Variable('fc7_weight') _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7') print('center-loss', args.center_alpha, args.center_scale) extra_loss = mx.symbol.Custom(data=embedding, label=gt_label, name='center_loss', op_type='centerloss',\ num_class=args.num_classes, alpha=args.center_alpha, scale=args.center_scale, batchsize=args.per_batch_size) elif args.loss_type==2: s = args.margin_s m = args.margin_m _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') if s>0.0: nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') if m>0.0: s_m = s*m gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0) fc7 = fc7-gt_one_hot else: fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') if m>0.0: body = embedding*embedding body = mx.sym.sum_axis(body, axis=1, keepdims=True) body = mx.sym.sqrt(body) body = body*m gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) body = mx.sym.broadcast_mul(gt_one_hot, body) fc7 = fc7-body elif args.loss_type==3: _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*22.0 fc7 = mx.sym.LSoftmax(data=nembedding, label=gt_label, num_hidden=args.num_classes, weight = _weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=1000, name='fc7') elif args.loss_type==4: s = args.margin_s m = args.margin_m cos_m = math.cos(m) sin_m = math.sin(m) _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') assert s>0.0 assert m>0.0 assert m<(math.pi/2) threshold = 0.0 threshold = math.cos(math.pi-m) nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy/s if threshold==0.0: cond = mx.symbol.Activation(data=cos_t, act_type='relu') else: cond_v = cos_t - threshold cond = mx.symbol.Activation(data=cond_v, act_type='relu') #theta = mx.sym.arccos(costheta) #sintheta = mx.sym.sin(theta) body = cos_t*cos_t body = 1.0-body sin_t = mx.sym.sqrt(body) new_zy = cos_t*cos_m b = sin_t*sin_m new_zy = new_zy - b new_zy = new_zy*s zy_keep = zy _zy = sin_t*(-1.0*s*m) zy_keep += _zy new_zy = mx.sym.where(cond, new_zy, zy_keep) diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) body = mx.sym.broadcast_mul(gt_one_hot, diff) fc7 = fc7+body #elif args.loss_type==4: # s = args.margin_s # m = args.margin_m # assert s>0.0 # #assert m>0.0 # _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0) # _weight = mx.symbol.L2Normalization(_weight, mode='instance') # nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s # fc7 = mx.sym.AmSoftmax(data=nembedding, label=gt_label, num_hidden=args.num_classes, # weight = _weight, verbose=1000, # margin = m, s = s, name='fc7') elif args.loss_type==10: #marginal loss nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') params = [1.2, 0.3, 1.0] n1 = mx.sym.expand_dims(nembedding, axis=1) #N,1,C n2 = mx.sym.expand_dims(nembedding, axis=0) #1,N,C body = mx.sym.broadcast_sub(n1, n2) #N,N,C body = body * body body = mx.sym.sum(body, axis=2) # N,N #body = mx.sym.sqrt(body) body = body - params[0] mask = mx.sym.Variable('extra') body = body*mask body = body+params[1] #body = mx.sym.maximum(body, 0.0) body = mx.symbol.Activation(data=body, act_type='relu') body = mx.sym.sum(body) body = body/(args.per_batch_size*args.per_batch_size-args.per_batch_size) extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[2]) elif args.loss_type==11: #npair loss params = [0.9, 0.2] nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') nembedding = mx.sym.transpose(nembedding) nembedding = mx.symbol.reshape(nembedding, (512, args.per_identities, args.images_per_identity)) nembedding = mx.sym.transpose(nembedding, axes=(2,1,0)) #2*id*512 #nembedding = mx.symbol.reshape(nembedding, (512, args.images_per_identity, args.per_identities)) #nembedding = mx.sym.transpose(nembedding, axes=(1,2,0)) #2*id*512 n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=1) n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=1, end=2) #n1 = [] #n2 = [] #for i in xrange(args.per_identities): # _n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i, end=2*i+1) # _n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i+1, end=2*i+2) # n1.append(_n1) # n2.append(_n2) #n1 = mx.sym.concat(*n1, dim=0) #n2 = mx.sym.concat(*n2, dim=0) #rembeddings = mx.symbol.reshape(nembedding, (args.images_per_identity, args.per_identities, 512)) #n1 = mx.symbol.slice_axis(rembeddings, axis=0, begin=0, end=1) #n2 = mx.symbol.slice_axis(rembeddings, axis=0, begin=1, end=2) n1 = mx.symbol.reshape(n1, (args.per_identities, 512)) n2 = mx.symbol.reshape(n2, (args.per_identities, 512)) cosine_matrix = mx.symbol.dot(lhs=n1, rhs=n2, transpose_b = True) #id*id, id=N of N-pair data_extra = mx.sym.Variable('extra') data_extra = mx.sym.slice_axis(data_extra, axis=0, begin=0, end=args.per_identities) mask = cosine_matrix * data_extra #body = mx.sym.mean(mask) fii = mx.sym.sum_axis(mask, axis=1) fij_fii = mx.sym.broadcast_sub(cosine_matrix, fii) fij_fii = mx.sym.exp(fij_fii) row = mx.sym.sum_axis(fij_fii, axis=1) row = mx.sym.log(row) body = mx.sym.mean(row) extra_loss = mx.sym.MakeLoss(body) elif args.loss_type==12: #triplet loss nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3) positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3) negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size) ap = anchor - positive an = anchor - negative ap = ap*ap an = an*an ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1) an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1) triplet_loss = mx.symbol.Activation(data = (ap-an+args.triplet_alpha), act_type='relu') triplet_loss = mx.symbol.mean(triplet_loss) #triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3) extra_loss = mx.symbol.MakeLoss(triplet_loss) elif args.loss_type==9: #coco loss centroids = [] for i in xrange(args.per_identities): xs = mx.symbol.slice_axis(embedding, axis=0, begin=i*args.images_per_identity, end=(i+1)*args.images_per_identity) mean = mx.symbol.mean(xs, axis=0, keepdims=True) mean = mx.symbol.L2Normalization(mean, mode='instance') centroids.append(mean) centroids = mx.symbol.concat(*centroids, dim=0) nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*args.coco_scale fc7 = mx.symbol.dot(nembedding, centroids, transpose_b = True) #(batchsize, per_identities) #extra_loss = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size #extra_loss = mx.symbol.BlockGrad(extra_loss) else: #embedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*float(args.loss_type) embedding = embedding * 5 _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2 fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight = _weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=100, name='fc7') #fc7 = mx.sym.Custom(data=embedding, label=gt_label, weight=_weight, num_hidden=args.num_classes, # beta=args.beta, margin=args.margin, scale=args.scale, # op_type='ASoftmax', name='fc7') if args.loss_type<=1 and args.incay>0.0: params = [1.e-10] sel = mx.symbol.argmax(data = fc7, axis=1) sel = (sel==gt_label) norm = embedding*embedding norm = mx.symbol.sum(norm, axis=1) norm = norm+params[0] feature_incay = sel/norm feature_incay = mx.symbol.mean(feature_incay) * args.incay extra_loss = mx.symbol.MakeLoss(feature_incay) #out = softmax #l2_embedding = mx.symbol.L2Normalization(embedding) #ce = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size #out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax, mx.symbol.BlockGrad(ce)]) out_list = [mx.symbol.BlockGrad(embedding)] softmax = None if args.loss_type<10: softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid') out_list.append(softmax) if softmax is None: out_list.append(mx.sym.BlockGrad(gt_label)) if extra_loss is not None: out_list.append(extra_loss) out = mx.symbol.Group(out_list) return (out, arg_params, aux_params) def train_net(args): ctx = [] cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip() if len(cvd)>0: for i in xrange(len(cvd.split(','))): ctx.append(mx.gpu(i)) if len(ctx)==0: ctx = [mx.cpu()] print('use cpu') else: print('gpu num:', len(ctx)) prefix = args.prefix prefix_dir = os.path.dirname(prefix) if not os.path.exists(prefix_dir): os.makedirs(prefix_dir) end_epoch = args.end_epoch args.ctx_num = len(ctx) args.num_layers = int(args.network[1:]) print('num_layers', args.num_layers) if args.per_batch_size==0: args.per_batch_size = 128 if args.loss_type==10: args.per_batch_size = 256 args.batch_size = args.per_batch_size*args.ctx_num args.rescale_threshold = 0 args.image_channel = 3 ppatch = [int(x) for x in args.patch.split('_')] assert len(ppatch)==5 os.environ['BETA'] = str(args.beta) data_dir_list = args.data_dir.split(',') if args.loss_type!=12: assert len(data_dir_list)==1 data_dir = data_dir_list[0] args.use_val = False path_imgrec = None path_imglist = None val_rec = None prop = face_image.load_property(data_dir) args.num_classes = prop.num_classes image_size = prop.image_size args.image_h = image_size[0] args.image_w = image_size[1] print('image_size', image_size) assert(args.num_classes>0) print('num_classes', args.num_classes) args.coco_scale = 0.5*math.log(float(args.num_classes-1))+3 #path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2" path_imgrec = os.path.join(data_dir, "train.rec") val_rec = os.path.join(data_dir, "val.rec") if os.path.exists(val_rec) and args.loss_type<10: args.use_val = True else: val_rec = None args.use_val = False #args.num_classes = 10572 #webface #args.num_classes = 81017 #args.num_classes = 82395 if (args.loss_type>=1 and args.loss_type<=5) and args.num_classes>40000: args.beta_freeze = 5000 args.gamma = 0.06 if args.loss_type<9: assert args.images_per_identity==0 else: if args.images_per_identity==0: if args.loss_type==11: args.images_per_identity = 2 elif args.loss_type==10 or args.loss_type==9: args.images_per_identity = 16 elif args.loss_type==12: args.images_per_identity = 5 assert args.per_batch_size%3==0 assert args.images_per_identity>=2 args.per_identities = int(args.per_batch_size/args.images_per_identity) print('Called with argument:', args) data_shape = (args.image_channel,image_size[0],image_size[1]) mean = None if args.use_val: val_dataiter = FaceImageIter( batch_size = args.batch_size, data_shape = data_shape, path_imgrec = val_rec, #path_imglist = val_path, shuffle = False, rand_mirror = False, mean = mean, ) else: val_dataiter = None begin_epoch = 0 base_lr = args.lr base_wd = args.wd base_mom = args.mom if len(args.pretrained)==0: arg_params = None aux_params = None sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params) else: vec = args.pretrained.split(',') print('loading', vec) _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1])) sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params) data_extra = None hard_mining = False triplet_params = None coco_mode = False if args.loss_type==10: hard_mining = True _shape = (args.batch_size, args.per_batch_size) data_extra = np.full(_shape, -1.0, dtype=np.float32) c = 0 while c=1 and args.loss_type<=5: lr_steps = [100000, 140000, 160000] p = 512.0/args.batch_size for l in xrange(len(lr_steps)): lr_steps[l] = int(lr_steps[l]*p) else: lr_steps = [int(x) for x in args.lr_steps.split(',')] print('lr_steps', lr_steps) def _batch_callback(param): #global global_step global_step[0]+=1 mbatch = global_step[0] for _lr in lr_steps: if mbatch==args.beta_freeze+_lr: opt.lr *= 0.1 print('lr change to', opt.lr) break _cb(param) if mbatch%1000==0: print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch) if mbatch>=0 and mbatch%args.verbose==0: acc_list = ver_test(mbatch) save_step[0]+=1 msave = save_step[0] do_save = False lfw_score = acc_list[0] if lfw_score>highest_acc[0]: highest_acc[0] = lfw_score if lfw_score>=0.998: do_save = True if acc_list[-1]>=highest_acc[-1]: highest_acc[-1] = acc_list[-1] if lfw_score>=0.99: do_save = True if args.ckpt==0: do_save = False elif args.ckpt>1: do_save = True #for i in xrange(len(acc_list)): # acc = acc_list[i] # if acc>=highest_acc[i]: # highest_acc[i] = acc # if lfw_score>=0.99: # do_save = True #if args.loss_type==1 and mbatch>lr_steps[-1] and mbatch%10000==0: # do_save = True if do_save: print('saving', msave) if val_dataiter is not None: val_test() arg, aux = model.get_params() mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux) #if acc>=highest_acc[0]: # lfw_npy = "%s-lfw-%04d" % (prefix, msave) # X = np.concatenate(embeddings_list, axis=0) # print('saving lfw npy', X.shape) # np.save(lfw_npy, X) print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1])) if mbatch<=args.beta_freeze: _beta = args.beta else: move = max(0, mbatch-args.beta_freeze) _beta = max(args.beta_min, args.beta*math.pow(1+args.gamma*move, -1.0*args.power)) #print('beta', _beta) os.environ['BETA'] = str(_beta) #epoch_cb = mx.callback.do_checkpoint(prefix, 1) epoch_cb = None #def _epoch_callback(epoch, sym, arg_params, aux_params): # print('epoch-end', epoch) model.fit(train_dataiter, begin_epoch = begin_epoch, num_epoch = end_epoch, eval_data = val_dataiter, eval_metric = eval_metrics, kvstore = 'device', optimizer = opt, #optimizer_params = optimizer_params, initializer = initializer, arg_params = arg_params, aux_params = aux_params, allow_missing = True, batch_end_callback = _batch_callback, epoch_end_callback = epoch_cb ) def main(): #time.sleep(3600*6.5) args = parse_args() train_net(args) if __name__ == '__main__': main()