From 5fcfb70197f7c6eb476122de79d7b6353a9a2d3f Mon Sep 17 00:00:00 2001 From: Jia Guo Date: Sun, 19 Nov 2017 21:32:45 +0800 Subject: [PATCH] tiny --- src/marginalnet.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/marginalnet.py b/src/marginalnet.py index da4c17b..e707ea3 100644 --- a/src/marginalnet.py +++ b/src/marginalnet.py @@ -92,7 +92,7 @@ def resnet(data, units, filters, rtype, workspace): def get_symbol(num_classes, num_layers, conv_workspace=256): data = mx.symbol.Variable('data') bn_mom = 0.9 - if num_layers==27: + if num_layers<29: data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data') else: data = data-127.5 @@ -110,6 +110,9 @@ def get_symbol(num_classes, num_layers, conv_workspace=256): rtype = 3 #use_last_bn = False #use_dropout = False + elif num_layers==51: + units = [2,3,15,3] + rtype = 3 body = resnet(data = data, units = units, filters = filter_list, rtype=rtype, workspace = conv_workspace) if use_dropout: