tiny on network symbols

This commit is contained in:
nttstar
2018-10-25 19:35:17 +08:00
parent bc1cf15c50
commit 50a2e2d2f0
2 changed files with 2 additions and 5 deletions

View File

@@ -47,7 +47,7 @@ def get_symbol(num_classes, **kwargs):
bn_mom = kwargs.get('bn_mom', 0.9)
wd_mult = kwargs.get('wd_mult', 1.)
version_output = kwargs.get('version_output', 'GNAP')
assert version_output=='GDC' or version_output=='GNAP'
#assert version_output=='GDC' or version_output=='GNAP'
fc_type = version_output
data = mx.symbol.Variable(name="data")
data = data-127.5

View File

@@ -479,10 +479,7 @@ def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck, **kwar
elif uv==4:
return residual_unit_v4(data, num_filter, stride, dim_match, name, bottle_neck, **kwargs)
else:
if version_input<=1:
return residual_unit_v3(data, num_filter, stride, dim_match, name, bottle_neck, **kwargs)
else:
return residual_unit_v3_x(data, num_filter, stride, dim_match, name, bottle_neck, **kwargs)
return residual_unit_v3(data, num_filter, stride, dim_match, name, bottle_neck, **kwargs)
def resnet(units, num_stages, filter_list, num_classes, bottle_neck, **kwargs):
bn_mom = kwargs.get('bn_mom', 0.9)