diff --git a/src/symbols/fresnet.py b/src/symbols/fresnet.py index e632e3a..0fbf5c6 100644 --- a/src/symbols/fresnet.py +++ b/src/symbols/fresnet.py @@ -291,7 +291,7 @@ def residual_unit_v3(data, num_filter, stride, dim_match, name, bottle_neck=True return bn3 + shortcut def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, use_se=False, bn_mom=0.9, workspace=256, memonger=False): - return residual_unit_v3(data, num_filter, stride, dim_match, name, bottle_neck, use_se, bn_mom, workspace, memonger) + return residual_unit_v1(data, num_filter, stride, dim_match, name, bottle_neck, use_se, bn_mom, workspace, memonger) def resnet(units, num_stages, filter_list, num_classes, bottle_neck=True, use_se=False, bn_mom=0.9, workspace=256, memonger=False): """Return ResNet symbol of @@ -310,6 +310,8 @@ def resnet(units, num_stages, filter_list, num_classes, bottle_neck=True, use_se workspace : int Workspace used in convolution operator """ + L_type = False + fc_type = 'B'#'A'-'E' num_unit = len(units) assert(num_unit == num_stages) data = mx.sym.Variable(name='data') @@ -317,41 +319,48 @@ def resnet(units, num_stages, filter_list, num_classes, bottle_neck=True, use_se data = data-127.5 data = data*0.0078125 #data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data') - #body = Conv(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3), - # no_bias=True, name="conv0", workspace=workspace) - body = Conv(data=data, num_filter=filter_list[0], kernel=(3,3), stride=(1,1), pad=(1, 1), - no_bias=True, name="conv0", workspace=workspace) + if not L_type: + body = Conv(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3), + no_bias=True, name="conv0", workspace=workspace) + else: + body = Conv(data=data, num_filter=filter_list[0], kernel=(3,3), stride=(1,1), pad=(1, 1), + no_bias=True, name="conv0", workspace=workspace) body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0') body = Act(data=body, act_type='relu', name='relu0') - #body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max') + body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max') for i in range(num_stages): - #body = residual_unit(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False, - # name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, use_se=use_se,workspace=workspace, - # memonger=memonger) - body = residual_unit(body, filter_list[i+1], (2, 2), False, - name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, use_se=use_se,workspace=workspace, - memonger=memonger) + body = residual_unit(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False, + name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, use_se=use_se,workspace=workspace, + memonger=memonger) + #body = residual_unit(body, filter_list[i+1], (2, 2), False, + # name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, use_se=use_se,workspace=workspace, + # memonger=memonger) for j in range(units[i]-1): body = residual_unit(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i+1, j+2), bottle_neck=bottle_neck, use_se=use_se, workspace=workspace, memonger=memonger) - #body = residual_unit(body, filter_list[i+1], (2,2), False, name='stage%d_unit%d' % (i+1, units[i]), - # bottle_neck=bottle_neck, use_se=use_se, workspace=workspace, memonger=memonger) - fc_type = 1#0 or 1 - if fc_type==0: + + if fc_type=='E': + body = mx.symbol.Dropout(data=body, p=0.4) + fc1 = mx.sym.FullyConnected(data=body, num_hidden=num_classes, name='pre_fc1') + fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1') + else: bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1') relu1 = Act(data=bn1, act_type='relu', name='relu1') # Although kernel is not used here when global_pool=True, we should put one pool1 = mx.sym.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1') flat = mx.sym.Flatten(data=pool1) - #fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1') - fc1 = mx.sym.FullyConnected(data=body, num_hidden=num_classes, name='pre_fc1') - fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1') - else: - body = mx.symbol.Dropout(data=body, p=0.4) - fc1 = mx.sym.FullyConnected(data=body, num_hidden=num_classes, name='pre_fc1') - fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1') + if fc_type=='A': + fc1 = flat + else: + #B + fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='pre_fc1') + if fc_type=='C': + fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1') + elif fc_type=='D': + fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1') + fc1 = Act(data=fc1, act_type='relu', name='fc1_relu') return fc1 def get_symbol(num_classes, num_layers, conv_workspace=256, **kwargs):