diff --git a/src/symbols/fresnet.py b/src/symbols/fresnet.py index af4ee2e..c089201 100644 --- a/src/symbols/fresnet.py +++ b/src/symbols/fresnet.py @@ -484,13 +484,23 @@ def resnet(units, num_stages, filter_list, num_classes, bottle_neck, **kwargs): if fc_type=='A': fc1 = flat else: - #B - fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='pre_fc1') - if fc_type=='C': - fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1') - elif fc_type=='D': - fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1') - fc1 = Act(data=fc1, act_type='relu', name='fc1_relu') + if fc_type=='G' or fc_type=='H': + fc1 = mx.symbol.Dropout(data=flat, p=0.2) + fc1 = mx.sym.FullyConnected(data=fc1, num_hidden=num_classes, name='pre_fc1') + if fc_type=='G': + return fc1 + else: + fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1') + return fc1 + else: + #B-D + #B + fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='pre_fc1') + if fc_type=='C': + fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1') + elif fc_type=='D': + fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1') + fc1 = Act(data=fc1, act_type='relu', name='fc1_relu') return fc1 def get_symbol(num_classes, num_layers, **kwargs):