From 2e75cd4df13a3fc494a1de02a718d7adb4b617ec Mon Sep 17 00:00:00 2001 From: nttstar Date: Fri, 18 Jan 2019 13:23:48 +0800 Subject: [PATCH] tiny --- recognition/symbol/symbol_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recognition/symbol/symbol_utils.py b/recognition/symbol/symbol_utils.py index 1402171..e07ebef 100644 --- a/recognition/symbol/symbol_utils.py +++ b/recognition/symbol/symbol_utils.py @@ -26,7 +26,7 @@ def Linear(data, num_filter=1, kernel=(1, 1), stride=(1, 1), pad=(0, 0), num_gro bn = mx.sym.BatchNorm(data=conv, name='%s%s_batchnorm' %(name, suffix), fix_gamma=False,momentum=bn_mom) return bn -def get_fc1(last_conv, num_classes, fc_type): +def get_fc1(last_conv, num_classes, fc_type, input_channel=512): body = last_conv if fc_type=='Z': body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1') @@ -75,7 +75,7 @@ def get_fc1(last_conv, num_classes, fc_type): fc1 = mx.sym.Flatten(data=fc1) fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=0.9, name='fc1') elif fc_type=="GDC": #mobilefacenet_v1 - conv_6_dw = Linear(last_conv, num_filter=512, num_group=512, kernel=(7,7), pad=(0, 0), stride=(1, 1), name="conv_6dw7_7") + conv_6_dw = Linear(last_conv, num_filter=input_channel, num_group=input_channel, kernel=(7,7), pad=(0, 0), stride=(1, 1), name="conv_6dw7_7") conv_6_f = mx.sym.FullyConnected(data=conv_6_dw, num_hidden=num_classes, name='pre_fc1') fc1 = mx.sym.BatchNorm(data=conv_6_f, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1') elif fc_type=='F':