mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-15 21:23:52 +00:00
add mobilenet multiplier
This commit is contained in:
@@ -42,40 +42,42 @@ def get_symbol(num_classes, **kwargs):
|
||||
assert version_input>=0
|
||||
version_output = kwargs.get('version_output', 'E')
|
||||
fc_type = version_output
|
||||
version_unit = kwargs.get('version_unit', 3)
|
||||
print(version_input, version_output, version_unit)
|
||||
#version_unit = kwargs.get('version_unit', 3)
|
||||
version_multiplier = kwargs.get('version_multiplier', 1.0)
|
||||
bf = int(32*version_multiplier)
|
||||
print(version_input, version_output, version_multiplier, bf)
|
||||
if version_input==0:
|
||||
conv_1 = Conv(data, num_filter=32, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_1") # 224/112
|
||||
conv_1 = Conv(data, num_filter=bf, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_1") # 224/112
|
||||
else:
|
||||
conv_1 = Conv(data, num_filter=32, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_1") # 224/112
|
||||
conv_2_dw = Conv(conv_1, num_group=32, num_filter=32, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_2_dw") # 112/112
|
||||
conv_2 = Conv(conv_2_dw, num_filter=64, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_2") # 112/112
|
||||
conv_3_dw = Conv(conv_2, num_group=64, num_filter=64, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_3_dw") # 112/56
|
||||
conv_3 = Conv(conv_3_dw, num_filter=128, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_3") # 56/56
|
||||
conv_4_dw = Conv(conv_3, num_group=128, num_filter=128, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_4_dw") # 56/56
|
||||
conv_4 = Conv(conv_4_dw, num_filter=128, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_4") # 56/56
|
||||
conv_5_dw = Conv(conv_4, num_group=128, num_filter=128, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_5_dw") # 56/28
|
||||
conv_5 = Conv(conv_5_dw, num_filter=256, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_5") # 28/28
|
||||
conv_6_dw = Conv(conv_5, num_group=256, num_filter=256, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_6_dw") # 28/28
|
||||
conv_6 = Conv(conv_6_dw, num_filter=256, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_6") # 28/28
|
||||
conv_7_dw = Conv(conv_6, num_group=256, num_filter=256, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_7_dw") # 28/14
|
||||
conv_7 = Conv(conv_7_dw, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_7") # 14/14
|
||||
conv_1 = Conv(data, num_filter=bf, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_1") # 224/112
|
||||
conv_2_dw = Conv(conv_1, num_group=bf, num_filter=bf, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_2_dw") # 112/112
|
||||
conv_2 = Conv(conv_2_dw, num_filter=bf*2, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_2") # 112/112
|
||||
conv_3_dw = Conv(conv_2, num_group=bf*2, num_filter=bf*2, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_3_dw") # 112/56
|
||||
conv_3 = Conv(conv_3_dw, num_filter=bf*4, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_3") # 56/56
|
||||
conv_4_dw = Conv(conv_3, num_group=bf*4, num_filter=bf*4, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_4_dw") # 56/56
|
||||
conv_4 = Conv(conv_4_dw, num_filter=bf*4, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_4") # 56/56
|
||||
conv_5_dw = Conv(conv_4, num_group=bf*4, num_filter=bf*4, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_5_dw") # 56/28
|
||||
conv_5 = Conv(conv_5_dw, num_filter=bf*8, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_5") # 28/28
|
||||
conv_6_dw = Conv(conv_5, num_group=bf*8, num_filter=bf*8, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_6_dw") # 28/28
|
||||
conv_6 = Conv(conv_6_dw, num_filter=bf*8, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_6") # 28/28
|
||||
conv_7_dw = Conv(conv_6, num_group=bf*8, num_filter=bf*8, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_7_dw") # 28/14
|
||||
conv_7 = Conv(conv_7_dw, num_filter=bf*16, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_7") # 14/14
|
||||
|
||||
conv_8_dw = Conv(conv_7, num_group=512, num_filter=512, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_8_dw") # 14/14
|
||||
conv_8 = Conv(conv_8_dw, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_8") # 14/14
|
||||
conv_9_dw = Conv(conv_8, num_group=512, num_filter=512, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_9_dw") # 14/14
|
||||
conv_9 = Conv(conv_9_dw, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_9") # 14/14
|
||||
conv_10_dw = Conv(conv_9, num_group=512, num_filter=512, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_10_dw") # 14/14
|
||||
conv_10 = Conv(conv_10_dw, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_10") # 14/14
|
||||
conv_11_dw = Conv(conv_10, num_group=512, num_filter=512, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_11_dw") # 14/14
|
||||
conv_11 = Conv(conv_11_dw, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_11") # 14/14
|
||||
conv_12_dw = Conv(conv_11, num_group=512, num_filter=512, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_12_dw") # 14/14
|
||||
conv_12 = Conv(conv_12_dw, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_12") # 14/14
|
||||
conv_8_dw = Conv(conv_7, num_group=bf*16, num_filter=bf*16, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_8_dw") # 14/14
|
||||
conv_8 = Conv(conv_8_dw, num_filter=bf*16, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_8") # 14/14
|
||||
conv_9_dw = Conv(conv_8, num_group=bf*16, num_filter=bf*16, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_9_dw") # 14/14
|
||||
conv_9 = Conv(conv_9_dw, num_filter=bf*16, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_9") # 14/14
|
||||
conv_10_dw = Conv(conv_9, num_group=bf*16, num_filter=bf*16, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_10_dw") # 14/14
|
||||
conv_10 = Conv(conv_10_dw, num_filter=bf*16, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_10") # 14/14
|
||||
conv_11_dw = Conv(conv_10, num_group=bf*16, num_filter=bf*16, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_11_dw") # 14/14
|
||||
conv_11 = Conv(conv_11_dw, num_filter=bf*16, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_11") # 14/14
|
||||
conv_12_dw = Conv(conv_11, num_group=bf*16, num_filter=bf*16, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_12_dw") # 14/14
|
||||
conv_12 = Conv(conv_12_dw, num_filter=bf*16, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_12") # 14/14
|
||||
|
||||
conv_13_dw = Conv(conv_12, num_group=512, num_filter=512, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_13_dw") # 14/7
|
||||
conv_13 = Conv(conv_13_dw, num_filter=1024, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_13") # 7/7
|
||||
conv_14_dw = Conv(conv_13, num_group=1024, num_filter=1024, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_14_dw") # 7/7
|
||||
conv_14 = Conv(conv_14_dw, num_filter=1024, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_14") # 7/7
|
||||
conv_13_dw = Conv(conv_12, num_group=bf*16, num_filter=bf*16, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_13_dw") # 14/7
|
||||
conv_13 = Conv(conv_13_dw, num_filter=bf*32, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_13") # 7/7
|
||||
conv_14_dw = Conv(conv_13, num_group=bf*32, num_filter=bf*32, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_14_dw") # 7/7
|
||||
conv_14 = Conv(conv_14_dw, num_filter=bf*32, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_14") # 7/7
|
||||
body = conv_14
|
||||
fc1 = symbol_utils.get_fc1(body, num_classes, fc_type)
|
||||
return fc1
|
||||
|
||||
@@ -98,6 +98,7 @@ def parse_args():
|
||||
parser.add_argument('--version-input', type=int, default=1, help='network input config')
|
||||
parser.add_argument('--version-output', type=str, default='E', help='network embedding output config')
|
||||
parser.add_argument('--version-unit', type=int, default=3, help='resnet unit config')
|
||||
parser.add_argument('--version-multiplier', type=float, default=1.0, help='filters multiplier')
|
||||
parser.add_argument('--version-act', type=str, default='prelu', help='network activation config')
|
||||
parser.add_argument('--use-deformable', type=int, default=0, help='use deformable cnn in network')
|
||||
parser.add_argument('--lr', type=float, default=0.1, help='start learning rate')
|
||||
@@ -144,8 +145,9 @@ def get_symbol(args, arg_params, aux_params):
|
||||
print('init mobilenet', args.num_layers)
|
||||
if args.num_layers==1:
|
||||
embedding = fmobilenet.get_symbol(args.emb_size,
|
||||
version_se=args.version_se, version_input=args.version_input,
|
||||
version_output=args.version_output, version_unit=args.version_unit)
|
||||
version_input=args.version_input,
|
||||
version_output=args.version_output,
|
||||
version_multiplier = args.version_multiplier)
|
||||
else:
|
||||
embedding = fmobilenetv2.get_symbol(args.emb_size)
|
||||
elif args.network[0]=='i':
|
||||
@@ -450,6 +452,7 @@ def train_net(args):
|
||||
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) #inception
|
||||
else:
|
||||
initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
|
||||
#initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
|
||||
_rescale = 1.0/args.ctx_num
|
||||
opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
|
||||
som = 20
|
||||
|
||||
Reference in New Issue
Block a user