"""A `MutableModule` implement the `BaseModule` API, and allows input shape
|
|
varying with training iterations. If shapes vary, executors will rebind,
|
|
using shared arrays from the initial module binded with maximum shape.
|
|
"""
|
|
|
|
import logging

from mxnet import context as ctx
from mxnet.initializer import Uniform
from mxnet.module.base_module import BaseModule
from mxnet.module.module import Module


class MutableModule(BaseModule):
    """A mutable module is a module that supports variable input data.

    Parameters
    ----------
    symbol : Symbol
    data_names : list of str
    label_names : list of str
    logger : Logger
    context : Context or list of Context
    work_load_list : list of number
    max_data_shapes : list of (name, shape) tuple, designating inputs whose shapes may vary
    max_label_shapes : list of (name, shape) tuple, designating inputs whose shapes may vary
    fixed_param_prefix : list of str, indicating fixed parameters
    """
    def __init__(self,
                 symbol,
                 data_names,
                 label_names,
                 logger=logging,
                 context=ctx.cpu(),
                 work_load_list=None,
                 max_data_shapes=None,
                 max_label_shapes=None,
                 fixed_param_prefix=None):
        super(MutableModule, self).__init__(logger=logger)
        self._symbol = symbol
        self._data_names = data_names
        self._label_names = label_names
        self._context = context
        self._work_load_list = work_load_list

        self._curr_module = None
        self._max_data_shapes = max_data_shapes
        self._max_label_shapes = max_label_shapes
        self._fixed_param_prefix = fixed_param_prefix

        # collect all arguments whose name contains one of the fixed prefixes
        fixed_param_names = list()
        if fixed_param_prefix is not None:
            for name in self._symbol.list_arguments():
                for prefix in self._fixed_param_prefix:
                    if prefix in name:
                        fixed_param_names.append(name)
        self._fixed_param_names = fixed_param_names

    def _reset_bind(self):
        self.binded = False
        self._curr_module = None

    @property
    def data_names(self):
        return self._data_names

    @property
    def output_names(self):
        return self._symbol.list_outputs()

    @property
    def data_shapes(self):
        assert self.binded
        return self._curr_module.data_shapes

    @property
    def label_shapes(self):
        assert self.binded
        return self._curr_module.label_shapes

    @property
    def output_shapes(self):
        assert self.binded
        return self._curr_module.output_shapes

    def get_params(self):
        assert self.binded and self.params_initialized
        return self._curr_module.get_params()

    def init_params(self,
                    initializer=Uniform(0.01),
                    arg_params=None,
                    aux_params=None,
                    allow_missing=False,
                    force_init=False,
                    allow_extra=False):
        if self.params_initialized and not force_init:
            return
        assert self.binded, 'call bind before initializing the parameters'
        self._curr_module.init_params(initializer=initializer,
                                      arg_params=arg_params,
                                      aux_params=aux_params,
                                      allow_missing=allow_missing,
                                      force_init=force_init,
                                      allow_extra=allow_extra)
        self.params_initialized = True

    def bind(self,
             data_shapes,
             label_shapes=None,
             for_training=True,
             inputs_need_grad=False,
             force_rebind=False,
             shared_module=None):
        # in case we already initialized the params, keep them
        if self.params_initialized:
            arg_params, aux_params = self.get_params()

        # force rebinding is typically used when one wants to switch from
        # the training phase to the prediction phase.
        if force_rebind:
            self._reset_bind()

        if self.binded:
            self.logger.warning('Already binded, ignoring bind()')
            return

        assert shared_module is None, 'shared_module for MutableModule is not supported'

        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad
        self.binded = True

        # merge the user-provided maximum shapes into a single lookup table
        max_shapes_dict = dict()
        if self._max_data_shapes is not None:
            max_shapes_dict.update(dict(self._max_data_shapes))
        if self._max_label_shapes is not None:
            max_shapes_dict.update(dict(self._max_label_shapes))

        max_data_shapes = list()
        for name, shape in data_shapes:
            if name in max_shapes_dict:
                max_data_shapes.append((name, max_shapes_dict[name]))
            else:
                max_data_shapes.append((name, shape))

        max_label_shapes = list()
        if label_shapes is not None:
            for name, shape in label_shapes:
                if name in max_shapes_dict:
                    max_label_shapes.append((name, max_shapes_dict[name]))
                else:
                    max_label_shapes.append((name, shape))

        if len(max_label_shapes) == 0:
            max_label_shapes = None

        # bind the internal module with the maximum shapes, so that modules
        # rebound later with smaller shapes can share its allocated memory
        module = Module(self._symbol,
                        self._data_names,
                        self._label_names,
                        logger=self.logger,
                        context=self._context,
                        work_load_list=self._work_load_list,
                        fixed_param_names=self._fixed_param_names)
        module.bind(max_data_shapes,
                    max_label_shapes,
                    for_training,
                    inputs_need_grad,
                    force_rebind=False,
                    shared_module=None)
        self._curr_module = module

        # copy back the saved params, if already initialized
        if self.params_initialized:
            self.set_params(arg_params, aux_params)

    def init_optimizer(self,
                       kvstore='local',
                       optimizer='sgd',
                       optimizer_params=(('learning_rate', 0.01), ),
                       force_init=False):
        assert self.binded and self.params_initialized
        if self.optimizer_initialized and not force_init:
            self.logger.warning('optimizer already initialized, ignoring.')
            return

        self._curr_module.init_optimizer(kvstore,
                                         optimizer,
                                         optimizer_params,
                                         force_init=force_init)
        self.optimizer_initialized = True

    def forward(self, data_batch, is_train=None):
        assert self.binded and self.params_initialized

        # get the shapes the current module is bound with
        if self._curr_module.label_shapes is not None:
            current_shapes = dict(self._curr_module.data_shapes +
                                  self._curr_module.label_shapes)
        else:
            current_shapes = dict(self._curr_module.data_shapes)

        # get the shapes of the incoming batch
        if data_batch.provide_label is not None:
            input_shapes = dict(data_batch.provide_data +
                                data_batch.provide_label)
        else:
            input_shapes = dict(data_batch.provide_data)

        # decide whether any shape changed
        shape_changed = False
        for k, v in current_shapes.items():
            if v != input_shapes[k]:
                shape_changed = True

        # if so, rebind a new module that shares memory with the current one
        if shape_changed:
            module = Module(self._symbol,
                            self._data_names,
                            self._label_names,
                            logger=self.logger,
                            context=self._context,
                            work_load_list=self._work_load_list,
                            fixed_param_names=self._fixed_param_names)
            module.bind(data_batch.provide_data,
                        data_batch.provide_label,
                        self._curr_module.for_training,
                        self._curr_module.inputs_need_grad,
                        force_rebind=False,
                        shared_module=self._curr_module)
            self._curr_module = module

        self._curr_module.forward(data_batch, is_train=is_train)

    def backward(self, out_grads=None):
        assert self.binded and self.params_initialized
        self._curr_module.backward(out_grads=out_grads)

    def update(self):
        assert self.binded and self.params_initialized and self.optimizer_initialized
        self._curr_module.update()

    def get_outputs(self, merge_multi_context=True):
        assert self.binded and self.params_initialized
        return self._curr_module.get_outputs(
            merge_multi_context=merge_multi_context)

    def get_input_grads(self, merge_multi_context=True):
        assert self.binded and self.params_initialized and self.inputs_need_grad
        return self._curr_module.get_input_grads(
            merge_multi_context=merge_multi_context)

    def update_metric(self, eval_metric, labels):
        assert self.binded and self.params_initialized
        self._curr_module.update_metric(eval_metric, labels)

    def install_monitor(self, mon):
        """Install monitor on all executors."""
        assert self.binded
        self._curr_module.install_monitor(mon)
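

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes MXNet is installed and uses a toy fully-connected symbol; the
# names, shapes and hyper-parameters below are arbitrary examples chosen to
# show how MutableModule rebinds when the batch shape changes.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import mxnet as mx

    data = mx.sym.Variable('data')
    fc = mx.sym.FullyConnected(data=data, num_hidden=10, name='fc')
    net = mx.sym.SoftmaxOutput(data=fc, name='softmax')

    # bind once with the maximum expected shape ...
    mod = MutableModule(net,
                        data_names=['data'],
                        label_names=['softmax_label'],
                        max_data_shapes=[('data', (8, 32))])
    mod.bind(data_shapes=[('data', (8, 32))],
             label_shapes=[('softmax_label', (8,))])
    mod.init_params()
    mod.init_optimizer()

    # ... then feed a smaller batch; forward() detects the shape change and
    # rebinds an executor that shares memory with the maximum-shape module.
    batch = mx.io.DataBatch(data=[mx.nd.ones((4, 32))],
                            label=[mx.nd.zeros((4,))],
                            provide_data=[('data', (4, 32))],
                            provide_label=[('softmax_label', (4,))])
    mod.forward(batch, is_train=True)
    mod.backward()
    mod.update()
    print([out.shape for out in mod.get_outputs()])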