diff --git a/recognition/sample_config.py b/recognition/sample_config.py index 8b11819..647d367 100644 --- a/recognition/sample_config.py +++ b/recognition/sample_config.py @@ -12,6 +12,7 @@ config.net_se = 0 config.net_act = 'prelu' config.net_unit = 3 config.net_input = 1 +config.net_blocks = [1,4,6,2] config.net_output = 'E' config.net_multiplier = 1.0 config.val_targets = ['lfw', 'cfp_fp', 'agedb_30'] @@ -59,6 +60,12 @@ network.y1.net_name = 'fmobilefacenet' network.y1.emb_size = 128 network.y1.net_output = 'GDC' +network.y2 = edict() +network.y2.net_name = 'fmobilefacenet' +network.y2.emb_size = 256 +network.y2.net_output = 'GDC' +network.y2.net_blocks = [2,8,16,4] + network.m1 = edict() network.m1.net_name = 'fmobilenet' network.m1.emb_size = 256 diff --git a/recognition/symbol/fmobilefacenet.py b/recognition/symbol/fmobilefacenet.py index 0cfacd6..3050fc9 100644 --- a/recognition/symbol/fmobilefacenet.py +++ b/recognition/symbol/fmobilefacenet.py @@ -53,14 +53,18 @@ def get_symbol(): data = mx.symbol.Variable(name="data") data = data-127.5 data = data*0.0078125 + blocks = config.net_blocks conv_1 = Conv(data, num_filter=64, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_1") - conv_2_dw = Conv(conv_1, num_group=64, num_filter=64, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_2_dw") + if blocks[0]==1: + conv_2_dw = Conv(conv_1, num_group=64, num_filter=64, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_2_dw") + else: + conv_2_dw = Residual(conv_1, num_block=blocks[0], num_out=64, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=64, name="res_2") conv_23 = DResidual(conv_2_dw, num_out=64, kernel=(3, 3), stride=(2, 2), pad=(1, 1), num_group=128, name="dconv_23") - conv_3 = Residual(conv_23, num_block=4, num_out=64, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=128, name="res_3") + conv_3 = Residual(conv_23, num_block=blocks[1], num_out=64, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=128, name="res_3") conv_34 = DResidual(conv_3, num_out=128, kernel=(3, 3), stride=(2, 2), pad=(1, 1), num_group=256, name="dconv_34") - conv_4 = Residual(conv_34, num_block=6, num_out=128, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=256, name="res_4") + conv_4 = Residual(conv_34, num_block=blocks[2], num_out=128, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=256, name="res_4") conv_45 = DResidual(conv_4, num_out=128, kernel=(3, 3), stride=(2, 2), pad=(1, 1), num_group=512, name="dconv_45") - conv_5 = Residual(conv_45, num_block=2, num_out=128, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=256, name="res_5") + conv_5 = Residual(conv_45, num_block=blocks[3], num_out=128, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=256, name="res_5") conv_6_sep = Conv(conv_5, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_6sep") fc1 = symbol_utils.get_fc1(conv_6_sep, num_classes, fc_type) diff --git a/recognition/symbol/symbol_utils.py b/recognition/symbol/symbol_utils.py index 67ad913..470880e 100644 --- a/recognition/symbol/symbol_utils.py +++ b/recognition/symbol/symbol_utils.py @@ -37,6 +37,10 @@ def get_fc1(last_conv, num_classes, fc_type, input_channel=512): body = mx.symbol.Dropout(data=body, p=0.4) fc1 = mx.sym.FullyConnected(data=body, num_hidden=num_classes, name='pre_fc1') fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1') + elif fc_type=='FC': + body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1') + fc1 = mx.sym.FullyConnected(data=body, num_hidden=num_classes, name='pre_fc1') + fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1') elif fc_type=='GAP': bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1') relu1 = Act(data=bn1, act_type=config.net_act, name='relu1')