# insightface/detection/retinaface/rcnn/tools/train_rpn.py

from __future__ import print_function

import argparse
import logging
import pprint
import mxnet as mx
from ..config import config, default, generate_config
from ..symbol import *
from ..core import callback, metric
from ..core.loader import AnchorLoaderFPN
from ..core.module import MutableModule
from ..utils.load_data import load_gt_roidb, merge_roidb, filter_roidb
from ..utils.load_model import load_param


def train_rpn(network, dataset, image_set, root_path, dataset_path, frequent,
kvstore, work_load_list, no_flip, no_shuffle, resume, ctx,
pretrained, epoch, prefix, begin_epoch, end_epoch, train_shared,
lr, lr_step):
# set up logger
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# setup config
assert config.TRAIN.BATCH_IMAGES == 1
# load symbol
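    # the builder name is assembled dynamically: 'get_' + network + '_rpn'
    # must be exported by the star import from rcnn.symbol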
sym = eval('get_' + network + '_rpn')()
feat_sym = []
for stride in config.RPN_FEAT_STRIDE:
feat_sym.append(sym.get_internals()['rpn_cls_score_stride%s_output' %
stride])
# setup multi-gpu
batch_size = len(ctx)
input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size
# print config
pprint.pprint(config)
# load dataset and prepare imdb for training
    image_sets = image_set.split('+')  # multiple image sets are joined with '+'
    roidbs = [
        load_gt_roidb(dataset,
                      iset,
                      root_path,
                      dataset_path,
                      flip=not no_flip) for iset in image_sets
    ]
roidb = merge_roidb(roidbs)
roidb = filter_roidb(roidb)
# load training data
#train_data = AnchorLoaderFPN(feat_sym, roidb, batch_size=input_batch_size, shuffle=not no_shuffle,
# ctx=ctx, work_load_list=work_load_list,
# feat_stride=config.RPN_FEAT_STRIDE, anchor_scales=config.ANCHOR_SCALES,
# anchor_ratios=config.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING,
# allowed_border=9999)
train_data = AnchorLoaderFPN(feat_sym,
roidb,
batch_size=input_batch_size,
shuffle=not no_shuffle,
ctx=ctx,
work_load_list=work_load_list)
# infer max shape
max_data_shape = [('data', (input_batch_size, 3,
max([v[0] for v in config.SCALES]),
max([v[1] for v in config.SCALES])))]
max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    print('providing maximum shape', max_data_shape, max_label_shape)
# infer shape
data_shape_dict = dict(train_data.provide_data + train_data.provide_label)
arg_shape, out_shape, aux_shape = sym.infer_shape(**data_shape_dict)
arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
    out_shape_dict = dict(zip(sym.list_outputs(), out_shape))
aux_shape_dict = dict(zip(sym.list_auxiliary_states(), aux_shape))
    print('output shape')
pprint.pprint(out_shape_dict)
# load and initialize params
if resume:
arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
else:
arg_params, aux_params = load_param(pretrained, epoch, convert=True)
init = mx.init.Xavier(factor_type="in",
rnd_type='gaussian',
magnitude=2)
init_internal = mx.init.Normal(sigma=0.01)
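    # parameters missing from the checkpoint are zero-filled; non-bias weights
    # are then re-drawn from N(0, 0.01^2), and missing aux states from Xavier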
for k in sym.list_arguments():
if k in data_shape_dict:
continue
if k not in arg_params:
            print('init', k)
arg_params[k] = mx.nd.zeros(shape=arg_shape_dict[k])
if not k.endswith('bias'):
init_internal(k, arg_params[k])
for k in sym.list_auxiliary_states():
if k not in aux_params:
            print('init', k)
aux_params[k] = mx.nd.zeros(shape=aux_shape_dict[k])
init(k, aux_params[k])
# check parameter shapes
for k in sym.list_arguments():
if k in data_shape_dict:
continue
assert k in arg_params, k + ' not initialized'
assert arg_params[k].shape == arg_shape_dict[k], \
'shape inconsistent for ' + k + ' inferred ' + str(arg_shape_dict[k]) + ' provided ' + str(arg_params[k].shape)
for k in sym.list_auxiliary_states():
assert k in aux_params, k + ' not initialized'
assert aux_params[k].shape == aux_shape_dict[k], \
'shape inconsistent for ' + k + ' inferred ' + str(aux_shape_dict[k]) + ' provided ' + str(aux_params[k].shape)
# create solver
data_names = [k[0] for k in train_data.provide_data]
label_names = [k[0] for k in train_data.provide_label]
if train_shared:
fixed_param_prefix = config.FIXED_PARAMS_SHARED
else:
fixed_param_prefix = config.FIXED_PARAMS
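    # MutableModule rebinds its executors when input shapes change between
    # batches; max_data_shapes/max_label_shapes bound the memory it reserves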
mod = MutableModule(sym,
data_names=data_names,
label_names=label_names,
logger=logger,
context=ctx,
work_load_list=work_load_list,
max_data_shapes=max_data_shape,
max_label_shapes=max_label_shape,
fixed_param_prefix=fixed_param_prefix)
# decide training params
# metric
eval_metric = metric.RPNAccMetric()
cls_metric = metric.RPNLogLossMetric()
bbox_metric = metric.RPNL1LossMetric()
eval_metrics = mx.metric.CompositeEvalMetric()
for child_metric in [eval_metric, cls_metric, bbox_metric]:
eval_metrics.add(child_metric)
# callback
batch_end_callback = []
batch_end_callback.append(
mx.callback.Speedometer(train_data.batch_size, frequent=frequent))
epoch_end_callback = mx.callback.do_checkpoint(prefix)
# decide learning rate
base_lr = lr
lr_factor = 0.1
lr_epoch = [int(epoch) for epoch in lr_step.split(',')]
lr_epoch_diff = [
epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
]
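    # keep only the decay steps still ahead of begin_epoch; the base rate is
    # pre-decayed once for every step already passed, so resuming is consistent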
lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
lr_iters = [
int(epoch * len(roidb) / batch_size) for epoch in lr_epoch_diff
]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_iters, lr_factor)
# optimizer
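    # with TRAIN.BATCH_IMAGES == 1 the effective batch equals the number of
    # GPUs, so rescale_grad = 1 / batch_size averages the per-image gradients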
optimizer_params = {
'momentum': 0.9,
'wd': 0.0001,
'learning_rate': lr,
'lr_scheduler': lr_scheduler,
'rescale_grad': (1.0 / batch_size),
'clip_gradient': 5
}
# train
mod.fit(train_data,
eval_metric=eval_metrics,
epoch_end_callback=epoch_end_callback,
batch_end_callback=batch_end_callback,
kvstore=kvstore,
optimizer='sgd',
optimizer_params=optimizer_params,
arg_params=arg_params,
aux_params=aux_params,
begin_epoch=begin_epoch,
num_epoch=end_epoch)


def parse_args():
parser = argparse.ArgumentParser(
description='Train a Region Proposal Network')
# general
parser.add_argument('--network',
help='network name',
default=default.network,
type=str)
parser.add_argument('--dataset',
help='dataset name',
default=default.dataset,
type=str)
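    # parse --network/--dataset first so generate_config can install the
    # matching defaults before the remaining options are declared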
args, rest = parser.parse_known_args()
generate_config(args.network, args.dataset)
parser.add_argument('--image_set',
help='image_set name',
default=default.image_set,
type=str)
parser.add_argument('--root_path',
help='output data folder',
default=default.root_path,
type=str)
parser.add_argument('--dataset_path',
help='dataset path',
default=default.dataset_path,
type=str)
# training
parser.add_argument('--frequent',
help='frequency of logging',
default=default.frequent,
type=int)
parser.add_argument('--kvstore',
help='the kv-store type',
default=default.kvstore,
type=str)
    parser.add_argument('--work_load_list',
                        help='work load for different devices',
                        default=None,
                        nargs='+',
                        type=float)
parser.add_argument('--no_flip',
help='disable flip images',
action='store_true')
parser.add_argument('--no_shuffle',
help='disable random shuffle',
action='store_true')
parser.add_argument('--resume',
help='continue training',
action='store_true')
# rpn
    parser.add_argument('--gpus',
                        help='GPU devices to train with, e.g. 0,1',
                        default='0',
                        type=str)
parser.add_argument('--pretrained',
help='pretrained model prefix',
default=default.pretrained,
type=str)
parser.add_argument('--pretrained_epoch',
help='pretrained model epoch',
default=default.pretrained_epoch,
type=int)
parser.add_argument('--prefix',
help='new model prefix',
default=default.rpn_prefix,
type=str)
parser.add_argument('--begin_epoch',
help='begin epoch of training',
default=0,
type=int)
parser.add_argument('--end_epoch',
help='end epoch of training',
default=default.rpn_epoch,
type=int)
parser.add_argument('--lr',
help='base learning rate',
default=default.rpn_lr,
type=float)
parser.add_argument('--lr_step',
help='learning rate steps (in epoch)',
default=default.rpn_lr_step,
type=str)
parser.add_argument('--train_shared',
help='second round train shared params',
action='store_true')
args = parser.parse_args()
return args


def main():
args = parse_args()
    print('Called with arguments:', args)
ctx = [mx.gpu(int(i)) for i in args.gpus.split(',')]
train_rpn(args.network,
args.dataset,
args.image_set,
args.root_path,
args.dataset_path,
args.frequent,
args.kvstore,
args.work_load_list,
args.no_flip,
args.no_shuffle,
args.resume,
ctx,
args.pretrained,
args.pretrained_epoch,
args.prefix,
args.begin_epoch,
args.end_epoch,
train_shared=args.train_shared,
lr=args.lr,
lr_step=args.lr_step)


if __name__ == '__main__':
main()
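
# Example invocation (illustrative, unverified values; the registered networks,
# datasets and their defaults live in rcnn.config). Running as a module keeps
# the relative imports working:
#
#   python -m rcnn.tools.train_rpn --network resnet --dataset retinaface --gpus 0,1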