mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-18 14:55:42 +00:00
add mobilenet
This commit is contained in:
@@ -24,7 +24,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), 'symbols'))
|
||||
import fresnet
|
||||
import finception_resnet_v2
|
||||
import spherenet
|
||||
import marginalnet
|
||||
import fmobilenet
|
||||
#import inceptions
|
||||
#import xception
|
||||
#import lfw
|
||||
@@ -143,8 +143,10 @@ def get_symbol(args, arg_params, aux_params):
|
||||
if args.network[0]=='s':
|
||||
embedding = spherenet.get_symbol(512, args.num_layers)
|
||||
elif args.network[0]=='m':
|
||||
print('init marginal', args.num_layers)
|
||||
embedding = marginalnet.get_symbol(512, args.num_layers)
|
||||
print('init mobilenet', args.num_layers)
|
||||
embedding = fmobilenet.get_symbol(512,
|
||||
use_se=args.use_se, version_input=args.version_input,
|
||||
version_output=args.version_output, version_unit=args.version_unit)
|
||||
elif args.network[0]=='i':
|
||||
print('init inception-resnet-v2', args.num_layers)
|
||||
embedding = finception_resnet_v2.get_symbol(512)
|
||||
@@ -355,8 +357,8 @@ def train_net(args):
|
||||
data_shape_dict = {'data': (args.batch_size,)+data_shape, 'softmax_label': (args.batch_size,)}
|
||||
if args.network[0]=='s':
|
||||
arg_params, aux_params = spherenet.init_weights(sym, data_shape_dict, args.num_layers)
|
||||
elif args.network[0]=='m':
|
||||
arg_params, aux_params = marginalnet.init_weights(sym, data_shape_dict, args.num_layers)
|
||||
#elif args.network[0]=='m':
|
||||
# arg_params, aux_params = marginalnet.init_weights(sym, data_shape_dict, args.num_layers)
|
||||
#resnet_dcn.init_weights(sym, data_shape_dict, arg_params, aux_params)
|
||||
else:
|
||||
#sym, arg_params, aux_params = mx.model.load_checkpoint(pretrained, load_epoch)
|
||||
|
||||
Reference in New Issue
Block a user