diff --git a/src/marginalnet.py b/src/marginalnet.py index da4c17b..e707ea3 100644 --- a/src/marginalnet.py +++ b/src/marginalnet.py @@ -92,7 +92,7 @@ def resnet(data, units, filters, rtype, workspace): def get_symbol(num_classes, num_layers, conv_workspace=256): data = mx.symbol.Variable('data') bn_mom = 0.9 - if num_layers==27: + if num_layers<29: data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data') else: data = data-127.5 @@ -110,6 +110,9 @@ def get_symbol(num_classes, num_layers, conv_workspace=256): rtype = 3 #use_last_bn = False #use_dropout = False + elif num_layers==51: + units = [2,3,15,3] + rtype = 3 body = resnet(data = data, units = units, filters = filter_list, rtype=rtype, workspace = conv_workspace) if use_dropout: