mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-15 04:37:50 +00:00
fix GNAP
This commit is contained in:
@@ -48,14 +48,19 @@ def get_fc1(last_conv, num_classes, fc_type):
|
||||
filters_in = num_classes
|
||||
else:
|
||||
body = last_conv
|
||||
conv_6_sep = mx.sym.BatchNorm(data=body, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn6f')
|
||||
spatial_norm=conv_6_sep*conv_6_sep
|
||||
body = mx.sym.BatchNorm(data=body, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn6f')
|
||||
|
||||
spatial_norm=body*body
|
||||
spatial_norm=mx.sym.sum(data=spatial_norm, axis=1, keepdims=True)
|
||||
spatial_sqrt=mx.sym.sqrt(spatial_norm)
|
||||
spatial_mean=mx.sym.mean(spatial_sqrt, axis=(1,2,3), keepdims=True)
|
||||
#spatial_mean=mx.sym.mean(spatial_sqrt)
|
||||
#spatial_mean=mx.sym.mean(spatial_sqrt, axis=(1,2,3), keepdims=True)
|
||||
spatial_mean=mx.sym.mean(spatial_sqrt)
|
||||
spatial_div_inverse=mx.sym.broadcast_div(spatial_mean, spatial_sqrt)
|
||||
body = mx.sym.broadcast_mul(body, spatial_div_inverse)
|
||||
|
||||
spatial_attention_inverse=mx.symbol.tile(spatial_div_inverse, reps=(1,filters_in,1,1))
|
||||
body=body*spatial_attention_inverse
|
||||
#body = mx.sym.broadcast_mul(body, spatial_div_inverse)
|
||||
|
||||
fc1 = mx.sym.Pooling(body, kernel=(7, 7), global_pool=True, pool_type='avg')
|
||||
if num_classes<filters_in:
|
||||
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn6w')
|
||||
|
||||
Reference in New Issue
Block a user