This commit is contained in:
nttstar
2017-11-19 21:32:19 +08:00

View File

@@ -92,7 +92,7 @@ def resnet(data, units, filters, rtype, workspace):
def get_symbol(num_classes, num_layers, conv_workspace=256):
data = mx.symbol.Variable('data')
bn_mom = 0.9
if num_layers==27:
if num_layers<29:
data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')
else:
data = data-127.5
@@ -110,6 +110,9 @@ def get_symbol(num_classes, num_layers, conv_workspace=256):
rtype = 3
#use_last_bn = False
#use_dropout = False
elif num_layers==51:
units = [2,3,15,3]
rtype = 3
body = resnet(data = data, units = units, filters = filter_list, rtype=rtype, workspace = conv_workspace)
if use_dropout: