mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-14 20:35:30 +00:00
125 lines
5.2 KiB
Python
125 lines
5.2 KiB
Python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
def check_contains(name, name_list):
|
|
for n in name_list:
|
|
if name in n:
|
|
return True
|
|
return False
|
|
|
|
|
|
def gather_optimization_pass(program, weight_name):
|
|
op_idxs = []
|
|
gather_grad_op = None
|
|
momentum_op = None
|
|
for idx, op in enumerate(program.global_block().ops):
|
|
if (op.type == 'gather_grad' or
|
|
op.type == 'momentum') and check_contains(weight_name,
|
|
op.input_arg_names):
|
|
op_idxs.append(idx)
|
|
if op.type == 'momentum':
|
|
momentum_op = op
|
|
if op.type == 'gather_grad':
|
|
gather_grad_op = op
|
|
|
|
if gather_grad_op is not None and momentum_op is not None:
|
|
inputs = {
|
|
'Param': momentum_op.input('Param'),
|
|
'Velocity': momentum_op.input('Velocity'),
|
|
'LearningRate': momentum_op.input('LearningRate'),
|
|
'Grad': gather_grad_op.input('Out@GRAD'),
|
|
'Index': gather_grad_op.input('Index'),
|
|
'Axis': gather_grad_op.input('Axis'),
|
|
}
|
|
outputs = {
|
|
'ParamOut': momentum_op.output('ParamOut'),
|
|
'VelocityOut': momentum_op.output('VelocityOut'),
|
|
}
|
|
if 'MasterParam' in momentum_op.input_names and len(
|
|
momentum_op.input('MasterParam')) > 0:
|
|
inputs['MasterParam'] = momentum_op.input('MasterParam')
|
|
if 'MasterParamOut' in momentum_op.output_names and len(
|
|
momentum_op.output('MasterParamOut')) > 0:
|
|
outputs['MasterParamOut'] = momentum_op.output('MasterParamOut')
|
|
|
|
attrs = {
|
|
'mu': momentum_op.attr('mu'),
|
|
'use_nesterov': momentum_op.attr('use_nesterov'),
|
|
'regularization_method': momentum_op.attr('regularization_method'),
|
|
'regularization_coeff': momentum_op.attr('regularization_coeff'),
|
|
'multi_precision': momentum_op.attr('multi_precision'),
|
|
'rescale_grad': momentum_op.attr('rescale_grad'),
|
|
'op_device': momentum_op.attr('op_device'),
|
|
'op_namescope': momentum_op.attr('op_namescope'),
|
|
'op_role': momentum_op.attr('op_role'),
|
|
'op_role_var': momentum_op.input('Param'),
|
|
'axis': gather_grad_op.attr('axis'),
|
|
}
|
|
program.global_block()._insert_op(
|
|
op_idxs[-1] + 1,
|
|
type='sparse_momentum',
|
|
inputs=inputs,
|
|
outputs=outputs,
|
|
attrs=attrs)
|
|
|
|
for idx in reversed(op_idxs):
|
|
program.global_block()._remove_op(idx, sync=False)
|
|
|
|
var_names = []
|
|
for idx, name in enumerate(program.global_block().vars):
|
|
if '@GRAD' in name and weight_name in name:
|
|
var_names.append(name)
|
|
for name in var_names:
|
|
program.global_block()._remove_var(name, sync=False)
|
|
program.global_block()._sync_with_cpp()
|
|
|
|
|
|
def amp_pass(program, weight_name):
|
|
for idx, op in enumerate(program.global_block().ops):
|
|
if (op.type == 'update_loss_scaling' or
|
|
op.type == 'check_finite_and_unscale'):
|
|
input_idxs = []
|
|
input_arg_names = op.input("X")
|
|
# input_arg_names.append(gather_grad_op.input('Out@GRAD')[0])
|
|
for i, name in enumerate(input_arg_names):
|
|
if '@GRAD' in name and weight_name in name:
|
|
input_idxs.append(i)
|
|
if len(input_idxs) > 0:
|
|
for i in reversed(input_idxs):
|
|
input_arg_names.pop(i)
|
|
op.desc.set_input("X", input_arg_names)
|
|
|
|
output_idxs = []
|
|
output_arg_names = op.output("Out")
|
|
# output_arg_names.append(gather_grad_op.input('Out@GRAD')[0])
|
|
for i, name in enumerate(output_arg_names):
|
|
if '@GRAD' in name and weight_name in name:
|
|
output_idxs.append(i)
|
|
if len(output_idxs) > 0:
|
|
for i in reversed(output_idxs):
|
|
output_arg_names.pop(i)
|
|
op.desc.set_output("Out", output_arg_names)
|
|
|
|
if op.type == 'check_finite_and_unscale':
|
|
op_role_idxs = []
|
|
op_role_var = op.attr("op_role_var")
|
|
for i, name in enumerate(op_role_var):
|
|
if '@GRAD' in name and weight_name in name:
|
|
op_role_idxs.append(i)
|
|
if len(op_role_idxs) > 0:
|
|
for i in reversed(op_role_idxs):
|
|
op_role_var.pop(i)
|
|
op.desc._set_attr("op_role_var", op_role_var)
|