This commit is contained in:
Jia Guo
2017-12-05 20:23:58 +08:00
parent fe0a9bd808
commit f297af8911

View File

@@ -484,13 +484,23 @@ def resnet(units, num_stages, filter_list, num_classes, bottle_neck, **kwargs):
if fc_type=='A':
fc1 = flat
else:
#B
fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='pre_fc1')
if fc_type=='C':
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
elif fc_type=='D':
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
fc1 = Act(data=fc1, act_type='relu', name='fc1_relu')
if fc_type=='G' or fc_type=='H':
fc1 = mx.symbol.Dropout(data=flat, p=0.2)
fc1 = mx.sym.FullyConnected(data=fc1, num_hidden=num_classes, name='pre_fc1')
if fc_type=='G':
return fc1
else:
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
return fc1
else:
#B-D
#B
fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='pre_fc1')
if fc_type=='C':
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
elif fc_type=='D':
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
fc1 = Act(data=fc1, act_type='relu', name='fc1_relu')
return fc1
def get_symbol(num_classes, num_layers, **kwargs):