mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-16 21:47:47 +00:00
move more to config
This commit is contained in:
@@ -13,6 +13,15 @@ config.net_input = 1
|
||||
config.net_output = 'E'
|
||||
config.net_multiplier = 1.0
|
||||
config.val_targets = ['lfw', 'cfp_fp', 'agedb_30']
|
||||
config.ce_loss = False
|
||||
config.fc7_lr_mult = 1.0
|
||||
config.fc7_wd_mult = 1.0
|
||||
config.fc7_no_bias = False
|
||||
config.max_steps = 0
|
||||
config.data_rand_mirror = True
|
||||
config.data_cutoff = False
|
||||
config.data_color = 0
|
||||
config.data_images_filter = 0
|
||||
|
||||
|
||||
# network settings
|
||||
@@ -51,7 +60,7 @@ dataset = edict()
|
||||
|
||||
dataset.emore = edict()
|
||||
dataset.emore.dataset = 'emore'
|
||||
dataset.emore.dataset_path = './faces_emore'
|
||||
dataset.emore.dataset_path = '../datasets/faces_emore'
|
||||
dataset.emore.num_classes = 85742
|
||||
dataset.emore.image_shape = (112,112,3)
|
||||
dataset.emore.val_targets = ['lfw', 'cfp_fp', 'agedb_30']
|
||||
@@ -59,10 +68,6 @@ dataset.emore.val_targets = ['lfw', 'cfp_fp', 'agedb_30']
|
||||
loss = edict()
|
||||
loss.softmax = edict()
|
||||
loss.softmax.loss_name = 'softmax'
|
||||
loss.softmax.loss_s = -1.0
|
||||
loss.softmax.loss_m1 = 0.0
|
||||
loss.softmax.loss_m2 = 0.0
|
||||
loss.softmax.loss_m3 = 0.0
|
||||
|
||||
loss.nsoftmax = edict()
|
||||
loss.nsoftmax.loss_name = 'margin_softmax'
|
||||
@@ -116,7 +121,7 @@ default = edict()
|
||||
# default network
|
||||
default.network = 'r100'
|
||||
default.pretrained = ''
|
||||
default.pretrained_epoch = 0
|
||||
default.pretrained_epoch = 1
|
||||
# default dataset
|
||||
default.dataset = 'emore'
|
||||
default.loss = 'arcface'
|
||||
|
||||
@@ -78,25 +78,16 @@ def parse_args():
|
||||
args, rest = parser.parse_known_args()
|
||||
generate_config(args.network, args.dataset, args.loss)
|
||||
parser.add_argument('--models-root', default=default.models_root, help='root directory to save model.')
|
||||
parser.add_argument('--pretrained', default='', help='pretrained model to load')
|
||||
parser.add_argument('--pretrained', default=default.pretrained, help='pretrained model to load')
|
||||
parser.add_argument('--pretrained-epoch', default=default.pretrained_epoch, help='pretrained epoch to load')
|
||||
parser.add_argument('--ckpt', type=int, default=default.ckpt, help='checkpoint saving option. 0: discard saving. 1: save when necessary. 2: always save')
|
||||
parser.add_argument('--verbose', type=int, default=default.verbose, help='do verification testing and model saving every verbose batches')
|
||||
parser.add_argument('--max-steps', type=int, default=0, help='max training batches')
|
||||
parser.add_argument('--end-epoch', type=int, default=100000, help='training epoch size.')
|
||||
parser.add_argument('--lr', type=float, default=default.lr, help='start learning rate')
|
||||
parser.add_argument('--lr-steps', type=str, default=default.lr_steps, help='steps of lr changing')
|
||||
parser.add_argument('--wd', type=float, default=default.wd, help='weight decay')
|
||||
parser.add_argument('--mom', type=float, default=default.mom, help='momentum')
|
||||
parser.add_argument('--frequent', type=int, default=default.frequent, help='')
|
||||
parser.add_argument('--fc7-wd-mult', type=float, default=1.0, help='weight decay mult for fc7')
|
||||
parser.add_argument('--fc7-lr-mult', type=float, default=1.0, help='lr mult for fc7')
|
||||
parser.add_argument("--fc7-no-bias", default=False, action="store_true" , help="fc7 no bias flag")
|
||||
parser.add_argument('--per-batch-size', type=int, default=default.per_batch_size, help='batch size in each context')
|
||||
parser.add_argument('--rand-mirror', type=int, default=1, help='if do random mirror in training')
|
||||
parser.add_argument('--cutoff', type=int, default=0, help='cut off aug')
|
||||
parser.add_argument('--color', type=int, default=0, help='color jittering aug')
|
||||
parser.add_argument('--images-filter', type=int, default=0, help='minimum images per identity filter')
|
||||
parser.add_argument('--ce-loss', default=False, action='store_true', help='if output ce loss')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@@ -107,14 +98,16 @@ def get_symbol(args):
|
||||
gt_label = all_label
|
||||
is_softmax = True
|
||||
if config.loss_name=='softmax': #softmax
|
||||
_weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size), lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult)
|
||||
if args.fc7_no_bias:
|
||||
_weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size),
|
||||
lr_mult=config.fc7_lr_mult, wd_mult=config.fc7_wd_mult, init=mx.init.Normal(0.01))
|
||||
if config.fc7_no_bias:
|
||||
fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=config.num_classes, name='fc7')
|
||||
else:
|
||||
_bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
|
||||
fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=config.num_classes, name='fc7')
|
||||
elif config.loss_name=='margin_softmax':
|
||||
_weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size), lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult)
|
||||
_weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size),
|
||||
lr_mult=config.fc7_lr_mult, wd_mult=config.fc7_wd_mult, init=mx.init.Normal(0.01))
|
||||
s = config.loss_s
|
||||
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
|
||||
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
|
||||
@@ -170,7 +163,7 @@ def get_symbol(args):
|
||||
if is_softmax:
|
||||
softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid')
|
||||
out_list.append(softmax)
|
||||
if args.ce_loss:
|
||||
if config.ce_loss:
|
||||
#ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label = gt_label, name='ce_loss')/args.per_batch_size
|
||||
body = mx.symbol.SoftmaxActivation(data=fc7)
|
||||
body = mx.symbol.log(body)
|
||||
@@ -200,7 +193,6 @@ def train_net(args):
|
||||
print('prefix', prefix)
|
||||
if not os.path.exists(prefix_dir):
|
||||
os.makedirs(prefix_dir)
|
||||
end_epoch = args.end_epoch
|
||||
args.ctx_num = len(ctx)
|
||||
args.batch_size = args.per_batch_size*args.ctx_num
|
||||
args.rescale_threshold = 0
|
||||
@@ -229,9 +221,8 @@ def train_net(args):
|
||||
data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
|
||||
spherenet.init_weights(sym, data_shape_dict, args.num_layers)
|
||||
else:
|
||||
vec = args.pretrained.split(',')
|
||||
print('loading', vec)
|
||||
_, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
|
||||
print('loading', args.pretrained, args.pretrained_epoch)
|
||||
_, arg_params, aux_params = mx.model.load_checkpoint(args.pretrained, args.pretrained_epoch)
|
||||
sym = get_symbol(args)
|
||||
|
||||
#label_name = 'softmax_label'
|
||||
@@ -250,9 +241,9 @@ def train_net(args):
|
||||
data_shape = data_shape,
|
||||
path_imgrec = path_imgrec,
|
||||
shuffle = True,
|
||||
rand_mirror = args.rand_mirror,
|
||||
rand_mirror = config.data_rand_mirror,
|
||||
mean = mean,
|
||||
cutoff = args.cutoff,
|
||||
cutoff = config.data_cutoff,
|
||||
ctx_num = args.ctx_num,
|
||||
images_per_identity = config.images_per_identity,
|
||||
triplet_params = triplet_params,
|
||||
@@ -267,15 +258,15 @@ def train_net(args):
|
||||
data_shape = data_shape,
|
||||
path_imgrec = path_imgrec,
|
||||
shuffle = True,
|
||||
rand_mirror = args.rand_mirror,
|
||||
rand_mirror = config.data_rand_mirror,
|
||||
mean = mean,
|
||||
cutoff = args.cutoff,
|
||||
color_jittering = args.color,
|
||||
images_filter = args.images_filter,
|
||||
cutoff = config.data_cutoff,
|
||||
color_jittering = config.data_color,
|
||||
images_filter = config.data_images_filter,
|
||||
)
|
||||
metric1 = AccMetric()
|
||||
eval_metrics = [mx.metric.create(metric1)]
|
||||
if args.ce_loss:
|
||||
if config.ce_loss:
|
||||
metric2 = LossValueMetric()
|
||||
eval_metrics.append( mx.metric.create(metric2) )
|
||||
|
||||
@@ -370,7 +361,7 @@ def train_net(args):
|
||||
arg, aux = model.get_params()
|
||||
mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
|
||||
print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
|
||||
if args.max_steps>0 and mbatch>args.max_steps:
|
||||
if config.max_steps>0 and mbatch>config.max_steps:
|
||||
sys.exit(0)
|
||||
|
||||
epoch_cb = None
|
||||
@@ -378,7 +369,7 @@ def train_net(args):
|
||||
|
||||
model.fit(train_dataiter,
|
||||
begin_epoch = begin_epoch,
|
||||
num_epoch = end_epoch,
|
||||
num_epoch = 999999,
|
||||
eval_data = val_dataiter,
|
||||
eval_metric = eval_metrics,
|
||||
kvstore = 'device',
|
||||
|
||||
Reference in New Issue
Block a user