diff --git a/src/symbols/fmobilenet.py b/src/symbols/fmobilenet.py index 24f02bd..5a2d183 100644 --- a/src/symbols/fmobilenet.py +++ b/src/symbols/fmobilenet.py @@ -42,40 +42,42 @@ def get_symbol(num_classes, **kwargs): assert version_input>=0 version_output = kwargs.get('version_output', 'E') fc_type = version_output - version_unit = kwargs.get('version_unit', 3) - print(version_input, version_output, version_unit) + #version_unit = kwargs.get('version_unit', 3) + version_multiplier = kwargs.get('version_multiplier', 1.0) + bf = int(32*version_multiplier) + print(version_input, version_output, version_multiplier, bf) if version_input==0: - conv_1 = Conv(data, num_filter=32, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_1") # 224/112 + conv_1 = Conv(data, num_filter=bf, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_1") # 224/112 else: - conv_1 = Conv(data, num_filter=32, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_1") # 224/112 - conv_2_dw = Conv(conv_1, num_group=32, num_filter=32, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_2_dw") # 112/112 - conv_2 = Conv(conv_2_dw, num_filter=64, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_2") # 112/112 - conv_3_dw = Conv(conv_2, num_group=64, num_filter=64, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_3_dw") # 112/56 - conv_3 = Conv(conv_3_dw, num_filter=128, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_3") # 56/56 - conv_4_dw = Conv(conv_3, num_group=128, num_filter=128, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_4_dw") # 56/56 - conv_4 = Conv(conv_4_dw, num_filter=128, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_4") # 56/56 - conv_5_dw = Conv(conv_4, num_group=128, num_filter=128, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_5_dw") # 56/28 - conv_5 = Conv(conv_5_dw, num_filter=256, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_5") # 28/28 - conv_6_dw = Conv(conv_5, num_group=256, num_filter=256, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_6_dw") # 28/28 - conv_6 = Conv(conv_6_dw, num_filter=256, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_6") # 28/28 - conv_7_dw = Conv(conv_6, num_group=256, num_filter=256, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_7_dw") # 28/14 - conv_7 = Conv(conv_7_dw, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_7") # 14/14 + conv_1 = Conv(data, num_filter=bf, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_1") # 224/112 + conv_2_dw = Conv(conv_1, num_group=bf, num_filter=bf, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_2_dw") # 112/112 + conv_2 = Conv(conv_2_dw, num_filter=bf*2, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_2") # 112/112 + conv_3_dw = Conv(conv_2, num_group=bf*2, num_filter=bf*2, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_3_dw") # 112/56 + conv_3 = Conv(conv_3_dw, num_filter=bf*4, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_3") # 56/56 + conv_4_dw = Conv(conv_3, num_group=bf*4, num_filter=bf*4, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_4_dw") # 56/56 + conv_4 = Conv(conv_4_dw, num_filter=bf*4, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_4") # 56/56 + conv_5_dw = Conv(conv_4, num_group=bf*4, num_filter=bf*4, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_5_dw") # 56/28 + conv_5 = Conv(conv_5_dw, num_filter=bf*8, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_5") # 28/28 + conv_6_dw = Conv(conv_5, num_group=bf*8, num_filter=bf*8, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_6_dw") # 28/28 + conv_6 = Conv(conv_6_dw, num_filter=bf*8, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_6") # 28/28 + conv_7_dw = Conv(conv_6, num_group=bf*8, num_filter=bf*8, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_7_dw") # 28/14 + conv_7 = Conv(conv_7_dw, num_filter=bf*16, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_7") # 14/14 - conv_8_dw = Conv(conv_7, num_group=512, num_filter=512, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_8_dw") # 14/14 - conv_8 = Conv(conv_8_dw, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_8") # 14/14 - conv_9_dw = Conv(conv_8, num_group=512, num_filter=512, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_9_dw") # 14/14 - conv_9 = Conv(conv_9_dw, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_9") # 14/14 - conv_10_dw = Conv(conv_9, num_group=512, num_filter=512, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_10_dw") # 14/14 - conv_10 = Conv(conv_10_dw, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_10") # 14/14 - conv_11_dw = Conv(conv_10, num_group=512, num_filter=512, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_11_dw") # 14/14 - conv_11 = Conv(conv_11_dw, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_11") # 14/14 - conv_12_dw = Conv(conv_11, num_group=512, num_filter=512, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_12_dw") # 14/14 - conv_12 = Conv(conv_12_dw, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_12") # 14/14 + conv_8_dw = Conv(conv_7, num_group=bf*16, num_filter=bf*16, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_8_dw") # 14/14 + conv_8 = Conv(conv_8_dw, num_filter=bf*16, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_8") # 14/14 + conv_9_dw = Conv(conv_8, num_group=bf*16, num_filter=bf*16, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_9_dw") # 14/14 + conv_9 = Conv(conv_9_dw, num_filter=bf*16, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_9") # 14/14 + conv_10_dw = Conv(conv_9, num_group=bf*16, num_filter=bf*16, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_10_dw") # 14/14 + conv_10 = Conv(conv_10_dw, num_filter=bf*16, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_10") # 14/14 + conv_11_dw = Conv(conv_10, num_group=bf*16, num_filter=bf*16, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_11_dw") # 14/14 + conv_11 = Conv(conv_11_dw, num_filter=bf*16, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_11") # 14/14 + conv_12_dw = Conv(conv_11, num_group=bf*16, num_filter=bf*16, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_12_dw") # 14/14 + conv_12 = Conv(conv_12_dw, num_filter=bf*16, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_12") # 14/14 - conv_13_dw = Conv(conv_12, num_group=512, num_filter=512, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_13_dw") # 14/7 - conv_13 = Conv(conv_13_dw, num_filter=1024, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_13") # 7/7 - conv_14_dw = Conv(conv_13, num_group=1024, num_filter=1024, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_14_dw") # 7/7 - conv_14 = Conv(conv_14_dw, num_filter=1024, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_14") # 7/7 + conv_13_dw = Conv(conv_12, num_group=bf*16, num_filter=bf*16, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_13_dw") # 14/7 + conv_13 = Conv(conv_13_dw, num_filter=bf*32, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_13") # 7/7 + conv_14_dw = Conv(conv_13, num_group=bf*32, num_filter=bf*32, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_14_dw") # 7/7 + conv_14 = Conv(conv_14_dw, num_filter=bf*32, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_14") # 7/7 body = conv_14 fc1 = symbol_utils.get_fc1(body, num_classes, fc_type) return fc1 diff --git a/src/train_softmax.py b/src/train_softmax.py index 5716ae7..80fa921 100644 --- a/src/train_softmax.py +++ b/src/train_softmax.py @@ -98,6 +98,7 @@ def parse_args(): parser.add_argument('--version-input', type=int, default=1, help='network input config') parser.add_argument('--version-output', type=str, default='E', help='network embedding output config') parser.add_argument('--version-unit', type=int, default=3, help='resnet unit config') + parser.add_argument('--version-multiplier', type=float, default=1.0, help='filters multiplier') parser.add_argument('--version-act', type=str, default='prelu', help='network activation config') parser.add_argument('--use-deformable', type=int, default=0, help='use deformable cnn in network') parser.add_argument('--lr', type=float, default=0.1, help='start learning rate') @@ -144,8 +145,9 @@ def get_symbol(args, arg_params, aux_params): print('init mobilenet', args.num_layers) if args.num_layers==1: embedding = fmobilenet.get_symbol(args.emb_size, - version_se=args.version_se, version_input=args.version_input, - version_output=args.version_output, version_unit=args.version_unit) + version_input=args.version_input, + version_output=args.version_output, + version_multiplier = args.version_multiplier) else: embedding = fmobilenetv2.get_symbol(args.emb_size) elif args.network[0]=='i': @@ -450,6 +452,7 @@ def train_net(args): initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) #inception else: initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2) + #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style _rescale = 1.0/args.ctx_num opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale) som = 20