Files
insightface/recognition/arcface_paddle/static/utils/optimization_pass.py
2021-10-11 10:16:02 +08:00

125 lines
5.2 KiB
Python

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def check_contains(name, name_list):
    """Return True if ``name`` occurs as a substring of any entry in ``name_list``.

    Used to test whether an op's argument names reference a given weight
    (Paddle decorates variable names, so substring match is intentional).
    """
    return any(name in candidate for candidate in name_list)
def gather_optimization_pass(program, weight_name):
    """Fuse a ``gather_grad`` + ``momentum`` op pair into one ``sparse_momentum`` op.

    Scans the program's global block for a ``gather_grad`` op and a
    ``momentum`` op whose argument names reference ``weight_name``, replaces
    the pair with a single ``sparse_momentum`` op (avoiding materializing the
    dense gradient of the full weight table), and removes the now-dead
    gradient variables.  The program is modified in place.

    Args:
        program: a Paddle static ``Program`` to rewrite in place.
        weight_name: name (substring) of the sparsely-updated weight, e.g.
            the classification FC weight in partial-FC training.
    """
    op_idxs = []
    gather_grad_op = None
    momentum_op = None
    # Locate the gather_grad / momentum ops that touch this weight.
    for idx, op in enumerate(program.global_block().ops):
        if (op.type == 'gather_grad' or
                op.type == 'momentum') and check_contains(weight_name,
                                                          op.input_arg_names):
            op_idxs.append(idx)
            if op.type == 'momentum':
                momentum_op = op
            if op.type == 'gather_grad':
                gather_grad_op = op
    # Only fuse when BOTH ops were found; otherwise leave the program alone.
    if gather_grad_op is not None and momentum_op is not None:
        # sparse_momentum takes the upstream gradient (Out@GRAD) and the
        # gather indices directly, instead of the dense gradient that
        # gather_grad would have produced.
        inputs = {
            'Param': momentum_op.input('Param'),
            'Velocity': momentum_op.input('Velocity'),
            'LearningRate': momentum_op.input('LearningRate'),
            'Grad': gather_grad_op.input('Out@GRAD'),
            'Index': gather_grad_op.input('Index'),
            'Axis': gather_grad_op.input('Axis'),
        }
        outputs = {
            'ParamOut': momentum_op.output('ParamOut'),
            'VelocityOut': momentum_op.output('VelocityOut'),
        }
        # Carry over the FP32 master weights when AMP multi-precision is on.
        if 'MasterParam' in momentum_op.input_names and len(
                momentum_op.input('MasterParam')) > 0:
            inputs['MasterParam'] = momentum_op.input('MasterParam')
        if 'MasterParamOut' in momentum_op.output_names and len(
                momentum_op.output('MasterParamOut')) > 0:
            outputs['MasterParamOut'] = momentum_op.output('MasterParamOut')
        # Copy the momentum hyper-parameters / bookkeeping attrs verbatim,
        # plus the gather axis from the gather_grad op.
        attrs = {
            'mu': momentum_op.attr('mu'),
            'use_nesterov': momentum_op.attr('use_nesterov'),
            'regularization_method': momentum_op.attr('regularization_method'),
            'regularization_coeff': momentum_op.attr('regularization_coeff'),
            'multi_precision': momentum_op.attr('multi_precision'),
            'rescale_grad': momentum_op.attr('rescale_grad'),
            'op_device': momentum_op.attr('op_device'),
            'op_namescope': momentum_op.attr('op_namescope'),
            'op_role': momentum_op.attr('op_role'),
            'op_role_var': momentum_op.input('Param'),
            'axis': gather_grad_op.attr('axis'),
        }
        # Insert the fused op right after the last op being replaced, so all
        # of its inputs are already defined at that point in the block.
        program.global_block()._insert_op(
            op_idxs[-1] + 1,
            type='sparse_momentum',
            inputs=inputs,
            outputs=outputs,
            attrs=attrs)
        # Remove replaced ops back-to-front so earlier indices stay valid.
        for idx in reversed(op_idxs):
            program.global_block()._remove_op(idx, sync=False)
        # Drop the weight's gradient variables, which no longer have writers.
        # Collect first, then remove: don't mutate the dict while iterating.
        var_names = []
        for idx, name in enumerate(program.global_block().vars):
            if '@GRAD' in name and weight_name in name:
                var_names.append(name)
        for name in var_names:
            program.global_block()._remove_var(name, sync=False)
        # Single sync at the end flushes the deferred (sync=False) edits.
        program.global_block()._sync_with_cpp()
def amp_pass(program, weight_name):
    """Detach ``weight_name``'s gradients from the AMP loss-scaling ops.

    After :func:`gather_optimization_pass` removes the dense gradient of the
    sparsely-updated weight, the AMP ops (``update_loss_scaling`` and
    ``check_finite_and_unscale``) would still reference it.  This pass strips
    any ``@GRAD`` argument containing ``weight_name`` from those ops' ``X``
    inputs, ``Out`` outputs, and (for ``check_finite_and_unscale``) the
    ``op_role_var`` attribute.  The program is modified in place.

    Args:
        program: a Paddle static ``Program`` to rewrite in place.
        weight_name: name (substring) of the weight whose grads to remove.
    """
    for idx, op in enumerate(program.global_block().ops):
        if (op.type == 'update_loss_scaling' or
                op.type == 'check_finite_and_unscale'):
            # --- strip matching names from the op's X inputs ---
            input_idxs = []
            input_arg_names = op.input("X")
            # input_arg_names.append(gather_grad_op.input('Out@GRAD')[0])
            for i, name in enumerate(input_arg_names):
                if '@GRAD' in name and weight_name in name:
                    input_idxs.append(i)
            if len(input_idxs) > 0:
                # Pop back-to-front so earlier indices stay valid.
                for i in reversed(input_idxs):
                    input_arg_names.pop(i)
                op.desc.set_input("X", input_arg_names)

            # --- strip matching names from the op's Out outputs ---
            output_idxs = []
            output_arg_names = op.output("Out")
            # output_arg_names.append(gather_grad_op.input('Out@GRAD')[0])
            for i, name in enumerate(output_arg_names):
                if '@GRAD' in name and weight_name in name:
                    output_idxs.append(i)
            if len(output_idxs) > 0:
                for i in reversed(output_idxs):
                    output_arg_names.pop(i)
                op.desc.set_output("Out", output_arg_names)

            # check_finite_and_unscale also records grads in op_role_var;
            # keep that attribute consistent with the pruned inputs.
            if op.type == 'check_finite_and_unscale':
                op_role_idxs = []
                op_role_var = op.attr("op_role_var")
                for i, name in enumerate(op_role_var):
                    if '@GRAD' in name and weight_name in name:
                        op_role_idxs.append(i)
                if len(op_role_idxs) > 0:
                    for i in reversed(op_role_idxs):
                        op_role_var.pop(i)
                    op.desc._set_attr("op_role_var", op_role_var)