diff --git a/src/symbols/symbol_utils.py b/src/symbols/symbol_utils.py index a25d488..e23b01b 100644 --- a/src/symbols/symbol_utils.py +++ b/src/symbols/symbol_utils.py @@ -48,14 +48,19 @@ def get_fc1(last_conv, num_classes, fc_type): filters_in = num_classes else: body = last_conv - conv_6_sep = mx.sym.BatchNorm(data=body, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn6f') - spatial_norm=conv_6_sep*conv_6_sep + body = mx.sym.BatchNorm(data=body, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn6f') + + spatial_norm=body*body spatial_norm=mx.sym.sum(data=spatial_norm, axis=1, keepdims=True) spatial_sqrt=mx.sym.sqrt(spatial_norm) - spatial_mean=mx.sym.mean(spatial_sqrt, axis=(1,2,3), keepdims=True) - #spatial_mean=mx.sym.mean(spatial_sqrt) + #spatial_mean=mx.sym.mean(spatial_sqrt, axis=(1,2,3), keepdims=True) + spatial_mean=mx.sym.mean(spatial_sqrt) spatial_div_inverse=mx.sym.broadcast_div(spatial_mean, spatial_sqrt) - body = mx.sym.broadcast_mul(body, spatial_div_inverse) + + spatial_attention_inverse=mx.symbol.tile(spatial_div_inverse, reps=(1,filters_in,1,1)) + body=body*spatial_attention_inverse + #body = mx.sym.broadcast_mul(body, spatial_div_inverse) + fc1 = mx.sym.Pooling(body, kernel=(7, 7), global_pool=True, pool_type='avg') if num_classes