Files
insightface/recognition/partial_fc/mxnet/symbol/memonger.py
2020-11-06 13:59:21 +08:00

176 lines
5.2 KiB
Python

import mxnet as mx
import math
def prod(shape):
    """Multiply together every entry of ``shape``.

    Parameters
    ----------
    shape : iterable
        The dimensions to multiply.

    Returns
    -------
    The product of all entries; ``1`` when ``shape`` is empty.
    """
    result = 1
    for dim in shape:
        result *= dim
    return result
def is_param(name):
    """Check whether ``name`` looks like a parameter name.

    A name is treated as a parameter when it ends with one of the
    suffixes ``weight``, ``bias``, ``beta`` or ``gamma``.

    Parameters
    ----------
    name : str
        Name to classify.

    Returns
    -------
    bool
        True if ``name`` ends with a parameter suffix, False otherwise
        (so ``'data'`` is never a parameter).
    """
    # str.endswith accepts a tuple, so one call covers every suffix.
    # 'data' matches none of them and therefore needs no special case.
    param_suffixes = ('weight', 'bias', 'beta', 'gamma')
    return name.endswith(param_suffixes)
def make_mirror_plan(sym, threshold, plan_info=None, **kwargs):
    """Memory allocation planner with a given threshold.

    The user passes in a network configuration, a threshold that limits
    the memory per block, and the input shape configuration.

    Parameters
    ----------
    sym : symbol
        Input configuration of symbols.
        The user needs to pre-mark the attribute "mirror_stage" on the nodes
        that can be book-kept as stage boundaries.
        The algorithm decides whether to disable mirroring on those nodes.
    threshold : integer
        A tuning parameter, in MB, for the approximate size of each
        stage block.
    plan_info : dict, optional
        Used to hold plan information. When given, the keys 'max_size'
        and 'save_size' are written, in the same unit as the internally
        accumulated sizes (element count * 4).
    **kwargs :
        The arguments to infer shape.

    Returns
    -------
    alloc_sym : symbol
        A copy of ``sym`` with force_mirroring tagged on the nodes for
        better allocation. The input ``sym`` itself is not returned.
    """
    # The threshold is given in MB; sizes below are accumulated in bytes.
    threshold_bytes = threshold << 20
    sym = sym.__copy__()
    internals = sym.get_internals()
    _, out_shapes, _ = internals.infer_shape(**kwargs)

    local_size = 0       # size accumulated since the last saved stage node
    save_size = 0        # total size of nodes whose mirroring is disabled
    max_size = 0         # largest local_size seen when a stage was cut
    last_stage = ''
    stage_decision = ''

    named_shapes = zip(internals.list_outputs(), out_shapes)
    for idx, (name, shape) in enumerate(named_shapes):
        if is_param(name):
            # Parameters do not take part in the mirroring plan.
            continue

        node = internals[idx]
        # 4 per element: presumably float32 storage; the dtype is not
        # checked here.
        node_size = prod(shape) * 4
        local_size += node_size
        # Mirror by default; stage nodes may be switched off below.
        node._set_attr(force_mirroring='True')

        stage = node.attr('mirror_stage')
        if stage is None:
            continue

        if stage == 'True' or stage != last_stage:
            # A new stage begins: keep (do not mirror) this node only when
            # enough memory has accumulated since the previous cut.
            if local_size > threshold_bytes:
                save_size += node_size
                max_size = max(max_size, local_size)
                local_size = 0
                stage_decision = 'False'
                node._set_attr(force_mirroring=stage_decision)
            else:
                stage_decision = 'True'
            last_stage = stage
        elif stage_decision == 'False':
            # Same stage as the previous stage node, which was kept:
            # keep this one as well.
            save_size += node_size
            node._set_attr(force_mirroring=stage_decision)

    if plan_info is not None:
        plan_info['max_size'] = max_size
        plan_info['save_size'] = save_size
    return sym
def get_cost(sym, type_dict=None, **kwargs):
    """Get the cost of the current symbolic plan by running bind on GPU.

    Parameters
    ----------
    sym : Symbolic Variable
        The plan to evaluate.
    type_dict : dict, optional
        Forwarded to ``simple_bind`` as ``type_dict``.
    **kwargs :
        Forwarded to ``simple_bind``.

    Returns
    -------
    int
        The number parsed from the executor's debug string (callers in
        this file report it as MB).
    """
    executor = sym.simple_bind(ctx=mx.gpu(),
                               grad_req='write',
                               type_dict=type_dict,
                               **kwargs)
    report_lines = executor.debug_str().split('\n')
    # The figure is the second whitespace-separated token on the
    # third-from-last line of the debug string.
    # NOTE(review): presumably the "Total <N> MB allocated" summary line;
    # confirm against the MXNet version in use.
    summary_tokens = report_lines[-3].split()
    return int(summary_tokens[1])
def search_plan(sym, ntrial=6, type_dict=None, **kwargs):
    """Quick heuristic search over possible plans to find a good memory plan.

    The search runs three warm-up trials, starting at threshold 0 and
    moving each time to a threshold guessed from the previous plan's
    statistics. It then runs up to ``ntrial`` evenly spaced grid-search
    trials and returns the cheapest plan seen.

    Parameters
    ----------
    sym : symbolic
        Symbolic configurations
    ntrial : integer
        Additional grid search steps. When zero or negative, the grid
        search is skipped and only the warm-up trials are considered.
    type_dict : dict, optional
        Forwarded to ``get_cost``.
    **kwargs :
        Forwarded to both ``make_mirror_plan`` and ``get_cost``.

    Returns
    -------
    symbol
        The planned symbol with the lowest cost reported by ``get_cost``;
        on ties the earliest trial wins.
    """
    history = []
    threshold = 0
    min_threshold = None
    nbegin = 3

    # Warm-up: each trial's statistics give the next threshold guess.
    for _ in range(nbegin):
        info = {}
        sym = make_mirror_plan(sym,
                               threshold=threshold,
                               plan_info=info,
                               **kwargs)
        cost = get_cost(sym, type_dict, **kwargs)
        # Plan statistics are converted to MB to match the threshold unit.
        save_size = info['save_size'] >> 20
        local_size = info['max_size'] >> 20
        guess = int(math.sqrt(save_size * local_size / 2))
        if min_threshold is None or local_size < min_threshold:
            min_threshold = local_size
        print("Search threshold=%d MB, cost=%d MB" % (threshold, cost))
        history.append((cost, threshold, sym))
        threshold = guess

    # Grid search between the smallest block size seen and sqrt(2) times
    # the last guess. Guarding on ntrial avoids a division by zero when
    # the caller asks for no additional trials.
    if ntrial > 0:
        max_threshold = threshold * math.sqrt(2)
        step = int((max_threshold - min_threshold) / ntrial)
        threshold = min_threshold + step
        if step > 0:
            for _ in range(ntrial):
                sym = make_mirror_plan(sym,
                                       threshold=threshold,
                                       **kwargs)
                cost = get_cost(sym, type_dict, **kwargs)
                print("Search threshold=%d MB, cost=%d MB" % (threshold, cost))
                history.append((cost, threshold, sym))
                threshold += step

    # min() returns the first minimal entry, matching a stable sort by cost.
    cost, threshold, sym = min(history, key=lambda trial: trial[0])
    print('Find best plan with threshold=%d, cost=%d MB' % (threshold, cost))
    return sym