per-image norm

This commit is contained in:
Jia Guo
2018-08-06 10:39:09 +08:00
parent aea56bb943
commit c17fc4a9d9

View File

@@ -52,10 +52,9 @@ def get_fc1(last_conv, num_classes, fc_type):
spatial_norm=conv_6_sep*conv_6_sep
spatial_norm=mx.sym.sum(data=spatial_norm, axis=1, keepdims=True)
spatial_sqrt=mx.sym.sqrt(spatial_norm)
spatial_mean=mx.sym.mean(spatial_sqrt)
spatial_mean=mx.sym.mean(spatial_sqrt, axis=(1,2,3), keepdims=True)
#spatial_mean=mx.sym.mean(spatial_sqrt)
spatial_div_inverse=mx.sym.broadcast_div(spatial_mean, spatial_sqrt)
#spatial_attention_inverse=mx.symbol.tile(spatial_div_inverse, reps=(1,filters_in,1,1))
#attention_re=body*spatial_attention_inverse
body = mx.sym.broadcast_mul(body, spatial_div_inverse)
fc1 = mx.sym.Pooling(body, kernel=(7, 7), global_pool=True, pool_type='avg')
if num_classes<filters_in: