add get_head

This commit is contained in:
Jia Guo
2017-12-11 20:07:12 +08:00
parent d45992e0cb
commit 6b00c055db
3 changed files with 75 additions and 11 deletions

View File

@@ -1,7 +1,7 @@
import os
# MXNET_CPU_WORKER_NTHREADS must be greater than 1 for custom op to work on CPU
os.environ['MXNET_CPU_WORKER_NTHREADS'] = '2'
#os.environ['MXNET_CPU_WORKER_NTHREADS'] = '2'
import mxnet as mx

View File

@@ -177,16 +177,15 @@ def get_symbol(num_classes = 1000, num_layers=92, **kwargs):
## define Dual Path Network
data = mx.symbol.Variable(name="data")
data = data-127.5
data = data*0.0078125
# conv1
if version_input==0:
conv1_x_1 = Conv(data=data, num_filter=128, kernel=(7, 7), name='conv1_x_1', pad=(3,3), stride=(2,2))
else:
conv1_x_1 = Conv(data=data, num_filter=128, kernel=(3, 3), name='conv1_x_1', pad=(3,3), stride=(1,1))
conv1_x_1 = BN_AC(conv1_x_1, name='conv1_x_1__relu-sp')
conv1_x_x = mx.symbol.Pooling(data=conv1_x_1, pool_type="max", kernel=(3, 3), pad=(1,1), stride=(2,2), name="pool1")
#data = data-127.5
#data = data*0.0078125
#if version_input==0:
# conv1_x_1 = Conv(data=data, num_filter=128, kernel=(7, 7), name='conv1_x_1', pad=(3,3), stride=(2,2))
#else:
# conv1_x_1 = Conv(data=data, num_filter=128, kernel=(3, 3), name='conv1_x_1', pad=(3,3), stride=(1,1))
#conv1_x_1 = BN_AC(conv1_x_1, name='conv1_x_1__relu-sp')
#conv1_x_x = mx.symbol.Pooling(data=conv1_x_1, pool_type="max", kernel=(3, 3), pad=(1,1), stride=(2,2), name="pool1")
conv1_x_x = symbol_utils.get_head(data, version_input, 128)
# conv2
bw = 256

View File

@@ -44,3 +44,68 @@ def get_fc1(last_conv, num_classes, fc_type):
fc1 = Act(data=fc1, act_type='relu', name='fc1_relu')
return fc1
def residual_unit_v3(data, num_filter, stride, dim_match, name, **kwargs):
"""Return ResNet Unit symbol for building ResNet
Parameters
----------
data : str
Input data
num_filter : int
Number of output channels
bnf : int
Bottle neck channels factor with regard to num_filter
stride : tuple
Stride used in convolution
dim_match : Boolean
True means channel number between input and output is the same, otherwise means differ
name : str
Base name of the operators
workspace : int
Workspace used in convolution operator
"""
bn_mom = kwargs.get('bn_mom', 0.9)
workspace = kwargs.get('workspace', 256)
memonger = kwargs.get('memonger', False)
#print('in unit3')
bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
conv1 = Conv(data=bn1, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv1')
bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
act1 = Act(data=bn2, act_type='relu', name=name + '_relu1')
conv2 = Conv(data=act1, num_filter=num_filter, kernel=(3,3), stride=stride, pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv2')
bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
if dim_match:
shortcut = data
else:
conv1sc = Conv(data=data, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
workspace=workspace, name=name+'_conv1sc')
shortcut = mx.sym.BatchNorm(data=conv1sc, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_sc')
if memonger:
shortcut._set_attr(mirror_stage='True')
return bn3 + shortcut
def get_head(data, version_input, num_filter):
data = data-127.5
data = data*0.0078125
#data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')
if version_input==0:
body = Conv(data=data, num_filter=num_filter, kernel=(7, 7), stride=(2,2), pad=(3, 3),
no_bias=True, name="conv0", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
body = Act(data=body, act_type='relu', name='relu0')
body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max')
else:
body = data
body = Conv(data=body, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
no_bias=True, name="conv0", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
body = Act(data=body, act_type='relu', name='relu0')
body = residual_unit_v3(body, num_filter, (2, 2), False,
name='head', **kwargs)
return body