From 6b00c055db07df36d667df7982ccf4b33f388a01 Mon Sep 17 00:00:00 2001 From: Jia Guo Date: Mon, 11 Dec 2017 20:07:12 +0800 Subject: [PATCH] add get_head --- src/losses/center_loss.py | 2 +- src/symbols/fdpn.py | 19 +++++------ src/symbols/symbol_utils.py | 65 +++++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 11 deletions(-) diff --git a/src/losses/center_loss.py b/src/losses/center_loss.py index 6c211ef..b14bda3 100644 --- a/src/losses/center_loss.py +++ b/src/losses/center_loss.py @@ -1,7 +1,7 @@ import os # MXNET_CPU_WORKER_NTHREADS must be greater than 1 for custom op to work on CPU -os.environ['MXNET_CPU_WORKER_NTHREADS'] = '2' +#os.environ['MXNET_CPU_WORKER_NTHREADS'] = '2' import mxnet as mx diff --git a/src/symbols/fdpn.py b/src/symbols/fdpn.py index 3ae77f0..0544f83 100644 --- a/src/symbols/fdpn.py +++ b/src/symbols/fdpn.py @@ -177,16 +177,15 @@ def get_symbol(num_classes = 1000, num_layers=92, **kwargs): ## define Dual Path Network data = mx.symbol.Variable(name="data") - data = data-127.5 - data = data*0.0078125 - - # conv1 - if version_input==0: - conv1_x_1 = Conv(data=data, num_filter=128, kernel=(7, 7), name='conv1_x_1', pad=(3,3), stride=(2,2)) - else: - conv1_x_1 = Conv(data=data, num_filter=128, kernel=(3, 3), name='conv1_x_1', pad=(3,3), stride=(1,1)) - conv1_x_1 = BN_AC(conv1_x_1, name='conv1_x_1__relu-sp') - conv1_x_x = mx.symbol.Pooling(data=conv1_x_1, pool_type="max", kernel=(3, 3), pad=(1,1), stride=(2,2), name="pool1") + #data = data-127.5 + #data = data*0.0078125 + #if version_input==0: + # conv1_x_1 = Conv(data=data, num_filter=128, kernel=(7, 7), name='conv1_x_1', pad=(3,3), stride=(2,2)) + #else: + # conv1_x_1 = Conv(data=data, num_filter=128, kernel=(3, 3), name='conv1_x_1', pad=(3,3), stride=(1,1)) + #conv1_x_1 = BN_AC(conv1_x_1, name='conv1_x_1__relu-sp') + #conv1_x_x = mx.symbol.Pooling(data=conv1_x_1, pool_type="max", kernel=(3, 3), pad=(1,1), stride=(2,2), name="pool1") + conv1_x_x = symbol_utils.get_head(data, version_input, 128) # conv2 bw = 256 diff --git a/src/symbols/symbol_utils.py b/src/symbols/symbol_utils.py index c76c2ce..b99842b 100644 --- a/src/symbols/symbol_utils.py +++ b/src/symbols/symbol_utils.py @@ -44,3 +44,68 @@ def get_fc1(last_conv, num_classes, fc_type): fc1 = Act(data=fc1, act_type='relu', name='fc1_relu') return fc1 +def residual_unit_v3(data, num_filter, stride, dim_match, name, **kwargs): + + """Return ResNet Unit symbol for building ResNet + Parameters + ---------- + data : str + Input data + num_filter : int + Number of output channels + bnf : int + Bottle neck channels factor with regard to num_filter + stride : tuple + Stride used in convolution + dim_match : Boolean + True means channel number between input and output is the same, otherwise means differ + name : str + Base name of the operators + workspace : int + Workspace used in convolution operator + """ + bn_mom = kwargs.get('bn_mom', 0.9) + workspace = kwargs.get('workspace', 256) + memonger = kwargs.get('memonger', False) + #print('in unit3') + bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1') + conv1 = Conv(data=bn1, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1), + no_bias=True, workspace=workspace, name=name + '_conv1') + bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2') + act1 = Act(data=bn2, act_type='relu', name=name + '_relu1') + conv2 = Conv(data=act1, num_filter=num_filter, kernel=(3,3), stride=stride, pad=(1,1), + no_bias=True, workspace=workspace, name=name + '_conv2') + bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3') + + if dim_match: + shortcut = data + else: + conv1sc = Conv(data=data, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True, + workspace=workspace, name=name+'_conv1sc') + shortcut = mx.sym.BatchNorm(data=conv1sc, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_sc') + if memonger: + shortcut._set_attr(mirror_stage='True') + return bn3 + shortcut + + +def get_head(data, version_input, num_filter): + data = data-127.5 + data = data*0.0078125 + #data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data') + if version_input==0: + body = Conv(data=data, num_filter=num_filter, kernel=(7, 7), stride=(2,2), pad=(3, 3), + no_bias=True, name="conv0", workspace=workspace) + body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0') + body = Act(data=body, act_type='relu', name='relu0') + body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max') + else: + body = data + body = Conv(data=body, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1), + no_bias=True, name="conv0", workspace=workspace) + body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0') + body = Act(data=body, act_type='relu', name='relu0') + body = residual_unit_v3(body, num_filter, (2, 2), False, + name='head', **kwargs) + return body + +