This commit is contained in:
Jia Guo
2017-12-02 17:30:56 +08:00
parent 09a7068151
commit fa2202ee50

View File

@@ -291,7 +291,7 @@ def residual_unit_v3(data, num_filter, stride, dim_match, name, bottle_neck=True
return bn3 + shortcut
def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, use_se=False, bn_mom=0.9, workspace=256, memonger=False):
# Thin dispatcher: forwards all ten arguments, unchanged and in the same order,
# to a concrete residual-unit builder and returns whatever that builder returns.
# Read literally, the first return (residual_unit_v3) always wins and the second
# return (residual_unit_v1) can never execute.
# NOTE(review): this text looks like a rendered diff whose +/- markers and
# indentation were stripped, so one of the two returns is presumably the removed
# line and the other the added one -- confirm which variant is intended before
# relying on either.
return residual_unit_v3(data, num_filter, stride, dim_match, name, bottle_neck, use_se, bn_mom, workspace, memonger)
return residual_unit_v1(data, num_filter, stride, dim_match, name, bottle_neck, use_se, bn_mom, workspace, memonger)
def resnet(units, num_stages, filter_list, num_classes, bottle_neck=True, use_se=False, bn_mom=0.9, workspace=256, memonger=False):
"""Return ResNet symbol of
@@ -310,6 +310,8 @@ def resnet(units, num_stages, filter_list, num_classes, bottle_neck=True, use_se
workspace : int
Workspace used in convolution operator
"""
L_type = False
fc_type = 'B'#'A'-'E'
num_unit = len(units)
assert(num_unit == num_stages)
data = mx.sym.Variable(name='data')
@@ -317,41 +319,48 @@ def resnet(units, num_stages, filter_list, num_classes, bottle_neck=True, use_se
data = data-127.5
data = data*0.0078125
#data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')
#body = Conv(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3),
# no_bias=True, name="conv0", workspace=workspace)
body = Conv(data=data, num_filter=filter_list[0], kernel=(3,3), stride=(1,1), pad=(1, 1),
no_bias=True, name="conv0", workspace=workspace)
if not L_type:
body = Conv(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3),
no_bias=True, name="conv0", workspace=workspace)
else:
body = Conv(data=data, num_filter=filter_list[0], kernel=(3,3), stride=(1,1), pad=(1, 1),
no_bias=True, name="conv0", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
body = Act(data=body, act_type='relu', name='relu0')
#body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max')
body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max')
for i in range(num_stages):
#body = residual_unit(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False,
# name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, use_se=use_se,workspace=workspace,
# memonger=memonger)
body = residual_unit(body, filter_list[i+1], (2, 2), False,
name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, use_se=use_se,workspace=workspace,
memonger=memonger)
body = residual_unit(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False,
name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, use_se=use_se,workspace=workspace,
memonger=memonger)
#body = residual_unit(body, filter_list[i+1], (2, 2), False,
# name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, use_se=use_se,workspace=workspace,
# memonger=memonger)
for j in range(units[i]-1):
body = residual_unit(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i+1, j+2),
bottle_neck=bottle_neck, use_se=use_se, workspace=workspace, memonger=memonger)
#body = residual_unit(body, filter_list[i+1], (2,2), False, name='stage%d_unit%d' % (i+1, units[i]),
# bottle_neck=bottle_neck, use_se=use_se, workspace=workspace, memonger=memonger)
fc_type = 1#0 or 1
if fc_type==0:
if fc_type=='E':
body = mx.symbol.Dropout(data=body, p=0.4)
fc1 = mx.sym.FullyConnected(data=body, num_hidden=num_classes, name='pre_fc1')
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
else:
bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1')
relu1 = Act(data=bn1, act_type='relu', name='relu1')
# Although kernel is not used here when global_pool=True, we should put one
pool1 = mx.sym.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
flat = mx.sym.Flatten(data=pool1)
#fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1')
fc1 = mx.sym.FullyConnected(data=body, num_hidden=num_classes, name='pre_fc1')
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
else:
body = mx.symbol.Dropout(data=body, p=0.4)
fc1 = mx.sym.FullyConnected(data=body, num_hidden=num_classes, name='pre_fc1')
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
if fc_type=='A':
fc1 = flat
else:
#B
fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='pre_fc1')
if fc_type=='C':
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
elif fc_type=='D':
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
fc1 = Act(data=fc1, act_type='relu', name='fc1_relu')
return fc1
def get_symbol(num_classes, num_layers, conv_workspace=256, **kwargs):