diff --git a/src/symbols/fdpn.py b/src/symbols/fdpn.py new file mode 100644 index 0000000..3ae77f0 --- /dev/null +++ b/src/symbols/fdpn.py @@ -0,0 +1,230 @@ +import mxnet as mx +import symbol_utils + +bn_momentum = 0.9 + +def BK(data): + return mx.symbol.BlockGrad(data=data) + +# - - - - - - - - - - - - - - - - - - - - - - - +# Fundamental Elements +def BN(data, fix_gamma=False, momentum=bn_momentum, name=None): + bn = mx.symbol.BatchNorm( data=data, fix_gamma=fix_gamma, momentum=bn_momentum, name=('%s__bn'%name)) + return bn + +def AC(data, act_type='relu', name=None): + act = mx.symbol.Activation(data=data, act_type=act_type, name=('%s__%s' % (name, act_type))) + return act + +def BN_AC(data, momentum=bn_momentum, name=None): + bn = BN(data=data, name=name, fix_gamma=False, momentum=momentum) + bn_ac = AC(data=bn, name=name) + return bn_ac + +def Conv(data, num_filter, kernel, stride=(1,1), pad=(0, 0), name=None, no_bias=True, w=None, b=None, attr=None, num_group=1): + Convolution = mx.symbol.Convolution + if w is None: + conv = Convolution(data=data, num_filter=num_filter, num_group=num_group, kernel=kernel, pad=pad, stride=stride, name=('%s__conv' %name), no_bias=no_bias, attr=attr) + else: + if b is None: + conv = Convolution(data=data, num_filter=num_filter, num_group=num_group, kernel=kernel, pad=pad, stride=stride, name=('%s__conv' %name), no_bias=no_bias, weight=w, attr=attr) + else: + conv = Convolution(data=data, num_filter=num_filter, num_group=num_group, kernel=kernel, pad=pad, stride=stride, name=('%s__conv' %name), no_bias=False, bias=b, weight=w, attr=attr) + return conv + +# - - - - - - - - - - - - - - - - - - - - - - - +# Standard Common functions < CVPR > +def Conv_BN( data, num_filter, kernel, pad, stride=(1,1), name=None, w=None, b=None, no_bias=True, attr=None, num_group=1): + cov = Conv( data=data, num_filter=num_filter, num_group=num_group, kernel=kernel, pad=pad, stride=stride, name=name, w=w, b=b, no_bias=no_bias, attr=attr) + cov_bn = BN( data=cov, name=('%s__bn' % name)) + return cov_bn + +def Conv_BN_AC(data, num_filter, kernel, pad, stride=(1,1), name=None, w=None, b=None, no_bias=True, attr=None, num_group=1): + cov_bn = Conv_BN(data=data, num_filter=num_filter, num_group=num_group, kernel=kernel, pad=pad, stride=stride, name=name, w=w, b=b, no_bias=no_bias, attr=attr) + cov_ba = AC( data=cov_bn, name=('%s__ac' % name)) + return cov_ba + +# - - - - - - - - - - - - - - - - - - - - - - - +# Standard Common functions < ECCV > +def BN_Conv( data, num_filter, kernel, pad, stride=(1,1), name=None, w=None, b=None, no_bias=True, attr=None, num_group=1): + bn = BN( data=data, name=('%s__bn' % name)) + bn_cov = Conv( data=bn, num_filter=num_filter, num_group=num_group, kernel=kernel, pad=pad, stride=stride, name=name, w=w, b=b, no_bias=no_bias, attr=attr) + return bn_cov + +def AC_Conv( data, num_filter, kernel, pad, stride=(1,1), name=None, w=None, b=None, no_bias=True, attr=None, num_group=1): + ac = AC( data=data, name=('%s__ac' % name)) + ac_cov = Conv( data=ac, num_filter=num_filter, num_group=num_group, kernel=kernel, pad=pad, stride=stride, name=name, w=w, b=b, no_bias=no_bias, attr=attr) + return ac_cov + +def BN_AC_Conv(data, num_filter, kernel, pad, stride=(1,1), name=None, w=None, b=None, no_bias=True, attr=None, num_group=1): + bn = BN( data=data, name=('%s__bn' % name)) + ba_cov = AC_Conv(data=bn, num_filter=num_filter, num_group=num_group, kernel=kernel, pad=pad, stride=stride, name=name, w=w, b=b, no_bias=no_bias, attr=attr) + return ba_cov + +def DualPathFactory(data, num_1x1_a, num_3x3_b, num_1x1_c, name, inc, G, _type='normal'): + kw = 3 + kh = 3 + pw = (kw-1)/2 + ph = (kh-1)/2 + + # type + if _type is 'proj': + key_stride = 1 + has_proj = True + if _type is 'down': + key_stride = 2 + has_proj = True + if _type is 'normal': + key_stride = 1 + has_proj = False + + # PROJ + if type(data) is list: + data_in = mx.symbol.Concat(*[data[0], data[1]], name=('%s_cat-input' % name)) + else: + data_in = data + + if has_proj: + c1x1_w = BN_AC_Conv( data=data_in, num_filter=(num_1x1_c+2*inc), kernel=( 1, 1), stride=(key_stride, key_stride), name=('%s_c1x1-w(s/%d)' %(name, key_stride)), pad=(0, 0)) + data_o1 = mx.symbol.slice_axis(data=c1x1_w, axis=1, begin=0, end=num_1x1_c, name=('%s_c1x1-w(s/%d)-split1' %(name, key_stride))) + data_o2 = mx.symbol.slice_axis(data=c1x1_w, axis=1, begin=num_1x1_c, end=(num_1x1_c+2*inc), name=('%s_c1x1-w(s/%d)-split2' %(name, key_stride))) + else: + data_o1 = data[0] + data_o2 = data[1] + + # MAIN + c1x1_a = BN_AC_Conv( data=data_in, num_filter=num_1x1_a, kernel=( 1, 1), pad=( 0, 0), name=('%s_c1x1-a' % name)) + c3x3_b = BN_AC_Conv( data=c1x1_a, num_filter=num_3x3_b, kernel=(kw, kh), pad=(pw, ph), name=('%s_c%dx%d-b' % (name,kw,kh)), stride=(key_stride,key_stride), num_group=G) + c1x1_c = BN_AC_Conv( data=c3x3_b, num_filter=(num_1x1_c+inc), kernel=( 1, 1), pad=( 0, 0), name=('%s_c1x1-c' % name)) + c1x1_c1= mx.symbol.slice_axis(data=c1x1_c, axis=1, begin=0, end=num_1x1_c, name=('%s_c1x1-c-split1' % name)) + c1x1_c2= mx.symbol.slice_axis(data=c1x1_c, axis=1, begin=num_1x1_c, end=(num_1x1_c+inc), name=('%s_c1x1-c-split2' % name)) + + # OUTPUTS + summ = mx.symbol.ElementWiseSum(*[data_o1, c1x1_c1], name=('%s_sum' % name)) + dense = mx.symbol.Concat( *[data_o2, c1x1_c2], name=('%s_cat' % name)) + + return [summ, dense] + +k_R = 160 + +G = 40 + +k_sec = { 2: 4, \ + 3: 8, \ + 4: 28, \ + 5: 3 } + +inc_sec= { 2: 16, \ + 3: 32, \ + 4: 32, \ + 5: 128 } + +def get_symbol(num_classes = 1000, num_layers=92, **kwargs): + if num_layers==68: + k_R = 128 + G = 32 + k_sec = { 2: 3, \ + 3: 4, \ + 4: 12, \ + 5: 3 } + inc_sec= { 2: 16, \ + 3: 32, \ + 4: 32, \ + 5: 64 } + elif num_layers==92: + k_R = 96 + G = 32 + k_sec = { 2: 3, \ + 3: 4, \ + 4: 20, \ + 5: 3 } + inc_sec= { 2: 16, \ + 3: 32, \ + 4: 24, \ + 5: 128 } + elif num_layers==107: + k_R = 200 + G = 50 + k_sec = { 2: 4, \ + 3: 8, \ + 4: 20, \ + 5: 3 } + inc_sec= { 2: 20, \ + 3: 64, \ + 4: 64, \ + 5: 128 } + elif num_layers==131: + k_R = 160 + G = 40 + k_sec = { 2: 4, \ + 3: 8, \ + 4: 28, \ + 5: 3 } + inc_sec= { 2: 16, \ + 3: 32, \ + 4: 32, \ + 5: 128 } + else: + raise ValueError("no experiments done on dpn num_layers {}, you can do it yourself".format(num_layers)) + + version_se = kwargs.get('version_se', 1) + version_input = kwargs.get('version_input', 1) + assert version_input>=0 + version_output = kwargs.get('version_output', 'E') + fc_type = version_output + version_unit = kwargs.get('version_unit', 3) + print(version_se, version_input, version_output, version_unit) + + ## define Dual Path Network + data = mx.symbol.Variable(name="data") + data = data-127.5 + data = data*0.0078125 + + # conv1 + if version_input==0: + conv1_x_1 = Conv(data=data, num_filter=128, kernel=(7, 7), name='conv1_x_1', pad=(3,3), stride=(2,2)) + else: + conv1_x_1 = Conv(data=data, num_filter=128, kernel=(3, 3), name='conv1_x_1', pad=(3,3), stride=(1,1)) + conv1_x_1 = BN_AC(conv1_x_1, name='conv1_x_1__relu-sp') + conv1_x_x = mx.symbol.Pooling(data=conv1_x_1, pool_type="max", kernel=(3, 3), pad=(1,1), stride=(2,2), name="pool1") + + # conv2 + bw = 256 + inc= inc_sec[2] + R = (k_R*bw)/256 + conv2_x_x = DualPathFactory( conv1_x_x, R, R, bw, 'conv2_x__1', inc, G, 'proj' ) + for i_ly in range(2, k_sec[2]+1): + conv2_x_x = DualPathFactory( conv2_x_x, R, R, bw, ('conv2_x__%d'% i_ly), inc, G, 'normal') + + # conv3 + bw = 512 + inc= inc_sec[3] + R = (k_R*bw)/256 + conv3_x_x = DualPathFactory( conv2_x_x, R, R, bw, 'conv3_x__1', inc, G, 'down' ) + for i_ly in range(2, k_sec[3]+1): + conv3_x_x = DualPathFactory( conv3_x_x, R, R, bw, ('conv3_x__%d'% i_ly), inc, G, 'normal') + + # conv4 + bw = 1024 + inc= inc_sec[4] + R = (k_R*bw)/256 + conv4_x_x = DualPathFactory( conv3_x_x, R, R, bw, 'conv4_x__1', inc, G, 'down' ) + for i_ly in range(2, k_sec[4]+1): + conv4_x_x = DualPathFactory( conv4_x_x, R, R, bw, ('conv4_x__%d'% i_ly), inc, G, 'normal') + + # conv5 + bw = 2048 + inc= inc_sec[5] + R = (k_R*bw)/256 + conv5_x_x = DualPathFactory( conv4_x_x, R, R, bw, 'conv5_x__1', inc, G, 'down' ) + for i_ly in range(2, k_sec[5]+1): + conv5_x_x = DualPathFactory( conv5_x_x, R, R, bw, ('conv5_x__%d'% i_ly), inc, G, 'normal') + + # output: concat + conv5_x_x = mx.symbol.Concat(*[conv5_x_x[0], conv5_x_x[1]], name='conv5_x_x_cat-final') + #conv5_x_x = BN_AC(conv5_x_x, name='conv5_x_x__relu-sp') + before_pool = conv5_x_x + fc1 = symbol_utils.get_fc1(before_pool, num_classes, fc_type) + return fc1 + + diff --git a/src/train_softmax.py b/src/train_softmax.py index 66be449..4967503 100644 --- a/src/train_softmax.py +++ b/src/train_softmax.py @@ -27,6 +27,7 @@ import finception_resnet_v2 import fmobilenet import fxception import fdensenet +import fdpn #import lfw import verification import sklearn @@ -149,6 +150,11 @@ def get_symbol(args, arg_params, aux_params): embedding = fxception.get_symbol(512, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) + elif args.network[0]=='p': + print('init dpn', args.num_layers) + embedding = fdpn.get_symbol(512, args.num_layers, + version_se=args.version_se, version_input=args.version_input, + version_output=args.version_output, version_unit=args.version_unit) else: print('init resnet', args.num_layers) embedding = fresnet.get_symbol(512, args.num_layers, @@ -344,7 +350,7 @@ def train_net(args): base_lr = args.lr base_wd = args.wd base_mom = 0.9 - if len(args.pretrained)>0: + if len(args.pretrained)==0: arg_params = None aux_params = None sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)