mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-18 14:55:42 +00:00
134 lines
5.3 KiB
Python
134 lines
5.3 KiB
Python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
def _shard_key_builder(sample_key):
    """Return a function that maps a rank index to its shard key.

    The key layout is taken from ``sample_key``: the text after the last
    '@' in the part before the first '.' is replaced by the zero-padded
    (5 digit) rank index, e.g. ``dist@fc@rank@00000.w_0`` -> index 3 gives
    ``dist@fc@rank@00003.w_0``.
    """
    head, dot, tail = sample_key.partition('.')
    prefix, sep, _ = head.rpartition('@')
    suffix = dot + tail

    def build(rank_idx):
        # Plain concatenation (instead of using the key itself as a
        # %-format template) keeps keys containing a literal '%' working.
        return '%s%s%05d%s' % (prefix, sep, rank_idx, suffix)

    return build


def _slice_classes(weight, lo, hi):
    """Slice ``weight`` along its class axis over the half-open range [lo, hi).

    2-D arrays are sliced along axis 1, everything else along axis 0.
    """
    if len(weight.shape) == 2:
        return weight[:, lo:hi]
    return weight[lo:hi]


def rearrange_weight(weight_dict, init_num_rank, new_num_rank):
    """
    A helper function to convert pre-trained distributed fc parameters for
    inference or fine-tuning. Note that the number of ranks or GPUs for
    inference or fine-tuning can be different from that for pre-training.

    The shards in ``weight_dict`` are treated as consecutive pieces of one
    class axis (axis 1 for 2-D arrays, axis 0 otherwise), each piece holding
    ``ceil(num_class / init_num_rank)`` classes except possibly the last.
    The same axis is re-cut into ``new_num_rank`` pieces of
    ``ceil(num_class / new_num_rank)`` classes.

    Args:
        weight_dict(dict): the dict storing the distributed parameters,
            key: eg. dist@fc@rank@00000.w_0
            value: numpy.ndarray
            All entries are counted towards the total number of classes, so
            the dict is expected to hold the shards of a single parameter.
        init_num_rank(int) : number of ranks (GPU devices) the pre-trained
            weight was split across.
        new_num_rank(int) : number of ranks (GPU devices) to rearrange the
            weight for.

    Returns:
        dict: rearranged weight for new_num_rank gpu devices, keyed with the
            same naming pattern as the input. ``weight_dict`` itself is
            returned unchanged when it is empty or when both rank counts
            are equal; otherwise every value is a newly allocated array.

    Raises:
        ValueError: if the classes cannot be split so that each of the
            new_num_rank shards receives at least one class.
        KeyError: if a shard key derived from the naming pattern is missing
            from weight_dict.
    """
    # Imported locally so the function does not rely on a module-level
    # numpy import being present.
    import numpy as np

    # Nothing to rearrange: same layout requested, or no parameters at all.
    if init_num_rank == new_num_rank or not weight_dict:
        return weight_dict

    # Derive the key pattern from the first key of the dict.
    key_of = _shard_key_builder(next(iter(weight_dict)))

    # Total number of classes over all pre-trained shards.
    num_class = sum(w.shape[1] if len(w.shape) == 2 else len(w)
                    for w in weight_dict.values())
    # Shard width (ceil division) in the pre-trained and the new layout.
    init_nshard = (num_class + init_num_rank - 1) // init_num_rank
    new_nshard = (num_class + new_num_rank - 1) // new_num_rank

    # The last new shard starts at new_nshard * (new_num_rank - 1); it must
    # start before num_class or it would be empty.
    if new_nshard * (new_num_rank - 1) >= num_class:
        raise ValueError(
            "num class {} cann't be rationally splited by num rank {}".format(
                num_class, new_num_rank))

    ret_dict = {}
    for new_idx in range(new_num_rank):
        # Global half-open class range [start, end) owned by the new shard.
        start = new_idx * new_nshard
        end = min(start + new_nshard, num_class)
        # Pre-trained shards that overlap this range.
        first_shard = start // init_nshard
        last_shard = (end - 1) // init_nshard

        pieces = []
        for init_idx in range(first_shard, last_shard + 1):
            # Translate the global range into offsets local to this shard.
            offset = init_idx * init_nshard
            lo = max(start - offset, 0)
            hi = min(end - offset, init_nshard)
            pieces.append(_slice_classes(weight_dict[key_of(init_idx)], lo, hi))

        # For 2-dimensional arrays concatenate along axis=1,
        # for 1-dimensional arrays along axis=0.
        ret_dict[key_of(new_idx)] = np.concatenate(
            pieces, axis=len(pieces[0].shape) - 1)

    return ret_dict
|