add mobilenet

This commit is contained in:
Jia Guo
2017-12-06 19:46:33 +08:00
parent f297af8911
commit 81efbd1940
2 changed files with 116 additions and 5 deletions

View File

@@ -24,7 +24,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), 'symbols'))
import fresnet
import finception_resnet_v2
import spherenet
import marginalnet
import fmobilenet
#import inceptions
#import xception
#import lfw
@@ -143,8 +143,10 @@ def get_symbol(args, arg_params, aux_params):
if args.network[0]=='s':
embedding = spherenet.get_symbol(512, args.num_layers)
elif args.network[0]=='m':
print('init marginal', args.num_layers)
embedding = marginalnet.get_symbol(512, args.num_layers)
print('init mobilenet', args.num_layers)
embedding = fmobilenet.get_symbol(512,
use_se=args.use_se, version_input=args.version_input,
version_output=args.version_output, version_unit=args.version_unit)
elif args.network[0]=='i':
print('init inception-resnet-v2', args.num_layers)
embedding = finception_resnet_v2.get_symbol(512)
@@ -355,8 +357,8 @@ def train_net(args):
data_shape_dict = {'data': (args.batch_size,)+data_shape, 'softmax_label': (args.batch_size,)}
if args.network[0]=='s':
arg_params, aux_params = spherenet.init_weights(sym, data_shape_dict, args.num_layers)
elif args.network[0]=='m':
arg_params, aux_params = marginalnet.init_weights(sym, data_shape_dict, args.num_layers)
#elif args.network[0]=='m':
# arg_params, aux_params = marginalnet.init_weights(sym, data_shape_dict, args.num_layers)
#resnet_dcn.init_weights(sym, data_shape_dict, arg_params, aux_params)
else:
#sym, arg_params, aux_params = mx.model.load_checkpoint(pretrained, load_epoch)