# insightface/recognition/partial_fc/mxnet/memory_module.py
import logging
import warnings
from collections import namedtuple
import horovod.mxnet as hvd
import mxnet as mx
import mxnet.ndarray as nd
import numpy as np
from default import config
from optimizer import DistributedOptimizer
class SampleDistributeModule(object):
"""
Module for large-scale distributed face-recognition training with optional class-center
sampling (Partial FC). Sampling incurs no measurable loss of accuracy on large-scale
identity classification: 8 NVIDIA RTX 2080Ti GPUs suffice for a classification task with
10 million identities, and 64 RTX 2080Ti GPUs can handle 100 million identities.
See the original paper:
https://arxiv.org/abs/2010.05222
Parameters
----------
symbol: Symbol
Backbone symbol.
fc7_model: Object
Margin-loss (fc7) head.
memory_bank: Object
Memory bank that maintains the local class centers and their momentum.
memory_optimizer: Optimizer
Updater of the memory bank; by default an SGD optimizer.
logger: Logger, optional
"""
def __init__(
self,
symbol,
fc7_model,
memory_bank,
memory_optimizer,
logger=logging,
):
self.size = hvd.size()
self.rank = hvd.rank()
self.local_rank = hvd.local_rank()
self.gpu = mx.gpu(self.local_rank)
self.cpu = mx.cpu() # `device_id` is not needed for CPU.
self.nd_cache = {}
self.embedding_size = config.embedding_size
self.batch_size = config.batch_size
self.num_update = 0
self.batch_end_param = namedtuple('batch_end_param',
['loss', 'num_epoch', 'num_update'])
self.fc7_model = fc7_model
self.symbol = symbol
self.logger = logger
self.backbone_module = mx.module.Module(self.symbol, ['data'],
['softmax_label'],
logger=self.logger,
context=self.gpu)
self.memory_bank = memory_bank
self.memory_optimizer = memory_optimizer
self.memory_lr = None
self.loss_cache = None
self.grad_cache = None
def forward_backward(self, data_batch):
"""A convenient function that calls both ``forward`` and ``backward``.
"""
total_feature, total_label = self.forward(data_batch, is_train=True)
self.backward_all(total_feature, total_label)
@staticmethod
def broadcast_parameters(params):
"""
:param params:
:return:
"""
rank_0_dict = {}
# Run broadcasts.
for key, tensor in params.items():
rank_0_dict[key] = hvd.broadcast(tensor, 0, key)
return rank_0_dict
def fit(self,
train_data,
optimizer_params,
batch_end_callback,
initializer,
arg_params=None,
aux_params=None):
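"""Bind the backbone, initialize and synchronize parameters, install the optimizers,
then run the training loop, prefetching the next batch after every update."""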
# Bind -> Init_params -> Init_optimizers
self.bind(train_data.provide_data, train_data.provide_label, True)
self.init_params(initializer, arg_params, aux_params, False)
self.init_optimizer(optimizer_params=optimizer_params)
# Sync init
_arg_params, _aux_params = self.backbone_module.get_params()
_arg_params_rank_0 = self.broadcast_parameters(_arg_params)
_aux_params_rank_0 = self.broadcast_parameters(_aux_params)
self.backbone_module.set_params(_arg_params_rank_0, _aux_params_rank_0)
# Training loop
num_epoch = 0
while True:
data_iter = iter(train_data)
end_of_batch = False
next_data_batch = next(data_iter)
while not end_of_batch:
data_batch = next_data_batch
self.forward_backward(data_batch)
self.update()
try:
# pre fetch next batch
next_data_batch = next(data_iter)
self.prepare(next_data_batch, sparse_row_id_fn=None)
except StopIteration:
num_epoch += 1
end_of_batch = True
logging.info('reset dataset')
train_data.reset()
if batch_end_callback is not None:
batch_end_params = self.batch_end_param(
loss=self.loss_cache,
num_epoch=num_epoch,
num_update=self.num_update)
batch_end_callback(batch_end_params)
def get_export_params(self):
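"""Return copies of the backbone's arg_params and aux_params (e.g. for export)."""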
_g, _x = self.backbone_module.get_params()
g = _g.copy()
x = _x.copy()
return g, x
def get_ndarray2(self, context, name, arr):
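"""Return a cached NDArray buffer keyed by ``name`` and ``context`` and copy ``arr`` into it."""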
key = "%s_%s" % (name, context)
if key not in self.nd_cache:
v = nd.zeros(shape=arr.shape, ctx=context, dtype=arr.dtype)
self.nd_cache[key] = v
else:
v = self.nd_cache[key]
arr.copyto(v)
return v
def get_ndarray(self, context, name, shape, dtype='float32'):
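"""Return a cached NDArray buffer keyed by ``name`` and ``context``, allocating it with zeros on first use."""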
key = "%s_%s" % (name, context)
if key not in self.nd_cache:
v = nd.zeros(shape=shape, ctx=context, dtype=dtype)
self.nd_cache[key] = v
else:
v = self.nd_cache[key]
return v
def init_params(self,
initializer,
arg_params=None,
aux_params=None,
allow_missing=False,
force_init=False,
allow_extra=False):
"""Initializes the parameters and auxiliary states.
Parameters
----------
initializer : Initializer
Called to initialize parameters if needed.
arg_params : dict
If not ``None``, should be a dictionary of existing arg_params. Initialization
will be copied from that.
aux_params : dict
If not ``None``, should be a dictionary of existing aux_params. Initialization
will be copied from that.
allow_missing : bool
If ``True``, params could contain missing values, and the initializer will be
called to fill those missing params.
force_init : bool
If ``True``, will force re-initialize even if already initialized.
allow_extra : boolean, optional
Whether to allow extra parameters that are not needed by the symbol.
If this is True, no error will be thrown when arg_params or aux_params
contain extra parameters that are not needed by the executor.
"""
# backbone
self.backbone_module.init_params(initializer=initializer,
arg_params=arg_params,
aux_params=aux_params,
allow_missing=allow_missing,
force_init=force_init,
allow_extra=allow_extra)
def prepare(self, data_batch, sparse_row_id_fn=None):
if sparse_row_id_fn is not None:
warnings.warn(
UserWarning("sparse_row_id_fn is not invoked for BaseModule."))
def allgather(self, tensor, name, shape, dtype, context):
""" Implement in-place AllGather using AllReduce
"""
assert isinstance(tensor, nd.NDArray), type(tensor)
assert isinstance(name, str), type(name)
assert isinstance(shape, tuple), type(shape)
assert isinstance(dtype, str), type(dtype)
assert isinstance(context, mx.context.Context), type(context)
total_tensor = self.get_ndarray(context=context,
name=name,
shape=shape,
dtype=dtype)
total_tensor[:] = 0 # zeroing the buffer before the all-reduce is essential
total_tensor[self.rank * self.batch_size:self.rank * self.batch_size +
self.batch_size] = tensor
hvd.allreduce_(total_tensor, average=False) # all-reduce in-place
return total_tensor
def forward(self, data_batch, is_train=None):
self.backbone_module.forward(data_batch, is_train=is_train)
if is_train:
self.num_update += 1
fc1 = self.backbone_module.get_outputs()[0]
label = data_batch.label[0]
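# Gather the features and labels from every rank so that each worker sees the full global batch.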
total_features = self.allgather(tensor=fc1,
name='total_feature',
shape=(self.batch_size * self.size,
self.embedding_size),
dtype='float32',
context=self.gpu)
total_labels = self.allgather(tensor=label,
name='total_label',
shape=(self.batch_size *
self.size, ),
dtype='int32',
context=self.cpu)
return total_features, total_labels
else:
return None
def backward_all(
self,
total_feature,
total_label,
):
# get memory bank learning rate
self.memory_lr = self.memory_optimizer.lr_scheduler(self.num_update)
self.grad_cache = self.get_ndarray(self.gpu, 'grad_cache',
total_feature.shape)
self.loss_cache = self.get_ndarray(self.gpu, 'loss_cache', [1])
self.grad_cache[:] = 0
self.loss_cache[:] = 0
if config.sample_ratio == 1:
# every class center participates: full-softmax backward
grad, loss = self.backward(total_feature, total_label)
else:
# only a subset of centers is sampled: partial-fc backward
grad, loss = self.backward_sample(total_feature, total_label)
self.loss_cache[0] = loss
total_feature_grad = grad
total_feature_grad = hvd.allreduce(total_feature_grad, average=False)
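# The summed gradient covers the whole gathered batch; keep only the rows of this rank's local batch.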
fc1_grad = total_feature_grad[self.batch_size *
self.rank:self.batch_size * self.rank +
self.batch_size]
self.backbone_module.backward(out_grads=[fc1_grad / self.size])
def get_outputs(self, merge_multi_context=True):
"""
Gets outputs of the previous forward computation.
Returns
-------
list of NDArray or list of list of NDArray
Output.
"""
return self.backbone_module.get_outputs(
merge_multi_context=merge_multi_context)
def update(self):
"""
Updates parameters according to the installed optimizer and the gradients computed
in the previous forward-backward batch.
"""
self.backbone_module.update()
mx.nd.waitall()
def bind(self, data_shapes, label_shapes=None, for_training=True):
self.backbone_module.bind(data_shapes,
label_shapes,
for_training=for_training)
def init_optimizer(self, optimizer_params, force_init=False):
"""
Installs and initializes optimizers.
Parameters
----------
optimizer_params : dict
Keyword arguments forwarded to ``mx.optimizer.SGD``,
e.g. ``{'learning_rate': 0.01}``.
force_init : bool
Default ``False``, indicating whether we should force re-initializing the
optimizer in the case an optimizer is already installed.
"""
optimizer_backbone = DistributedOptimizer(
mx.optimizer.SGD(**optimizer_params))
self.backbone_module.init_optimizer('local',
optimizer_backbone,
force_init=force_init)
def backward(self, total_feature, label):
memory_bank = self.memory_bank
assert memory_bank.num_local == memory_bank.num_sample, "the full (non-sampled) backward requires sample_ratio == 1"
_data = self.get_ndarray2(self.gpu, "data_%d" % self.rank,
total_feature)
# Attach grad
_data.attach_grad()
memory_bank.weight.attach_grad()
# Convert label
_label = self.get_ndarray2(self.gpu, 'label_%d' % self.rank, label)
_label = _label - int(self.rank * memory_bank.num_local)
_fc7, _one_hot = self.fc7_model.forward(_data,
memory_bank.weight,
mapping_label=_label,
depth=memory_bank.num_local)
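# Distributed, numerically stable softmax over class centers sharded across ranks: each rank
# only holds the logits of its own shard, so the global row-wise max and the global sum of
# exp(logit - max) are obtained via all-reduce before forming prob = exp(logit - max) / sum.
# The cross-entropy gradient w.r.t. the logits is then prob - one_hot.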
# Sync max
max_fc7 = nd.max(_fc7, axis=1, keepdims=True)
max_fc7 = nd.reshape(max_fc7, -1)
total_max_fc7 = self.get_ndarray(context=self.gpu,
name='total_max_fc7',
shape=(max_fc7.shape[0], self.size),
dtype='float32')
total_max_fc7[:] = 0
total_max_fc7[:, self.rank] = max_fc7
hvd.allreduce_(total_max_fc7, average=False)
global_max_fc7 = self.get_ndarray(context=self.gpu,
name='global_max_fc7',
shape=(max_fc7.shape[0], 1),
dtype='float32')
nd.max(total_max_fc7, axis=1, keepdims=True, out=global_max_fc7)
# Calculate exp(logits)
_fc7_grad = nd.broadcast_sub(_fc7, global_max_fc7)
_fc7_grad = nd.exp(_fc7_grad)
# Calculate sum
sum_fc7 = nd.sum(_fc7_grad, axis=1, keepdims=True)
global_sum_fc7 = hvd.allreduce(sum_fc7, average=False)
# Calculate prob
_fc7_grad = nd.broadcast_div(_fc7_grad, global_sum_fc7)
# Calculate loss
tmp = _fc7_grad * _one_hot
tmp = nd.sum(tmp, axis=1, keepdims=True)
tmp = self.get_ndarray2(self.gpu, 'ctx_loss', tmp)
tmp = hvd.allreduce(tmp, average=False)
global_loss = -nd.mean(nd.log(tmp + 1e-30))
# Calculate fc7 grad
_fc7_grad = _fc7_grad - _one_hot
# Backward
_fc7.backward(out_grad=_fc7_grad)
# Update center
_weight_grad = memory_bank.weight.grad
self.memory_optimizer.update(weight=memory_bank.weight,
grad=_weight_grad,
state=memory_bank.weight_mom,
learning_rate=self.memory_lr)
return _data.grad, global_loss
def backward_sample(self, total_feature, label):
this_rank_classes = int(self.memory_bank.num_sample)
local_index, unique_sorted_global_label = self.memory_bank.sample(
label)
# Get local index
_mapping_dict = {}
local_sampled_class = local_index + self.rank * self.memory_bank.num_local
global_label_set = set(unique_sorted_global_label)
for idx, absolute_label in enumerate(local_sampled_class):
if absolute_label in global_label_set:
_mapping_dict[
absolute_label] = idx + self.rank * self.memory_bank.num_sample
label_list = list(label.asnumpy())
mapping_label = []
for i in range(len(label_list)):
absolute_label = label_list[i]
if absolute_label in _mapping_dict.keys():
mapping_label.append(_mapping_dict[absolute_label])
else:
mapping_label.append(-1)
mapping_label = nd.array(mapping_label, dtype=np.int32)
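# Positive classes sampled on this rank are remapped to their positions in the globally
# concatenated sampled-weight matrix; a mapping of -1 means the sample's positive center is
# not held on this rank, so locally that sample only contributes to the softmax denominator.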
# Get weight
local_index = nd.array(local_index)
local_index = self.get_ndarray2(self.gpu, "local_index", local_index)
sample_weight, sample_weight_mom = self.memory_bank.get(local_index)
# Sync to gpu: copy the gathered features, the sampled centers and their momentum to this GPU.
_data = self.get_ndarray2(self.gpu, "data_%d" % self.rank,
total_feature)
_weight = self.get_ndarray2(self.gpu, 'weight_%d' % self.rank,
sample_weight)
_weight_mom = self.get_ndarray2(self.gpu,
'weight_mom_%d' % self.rank,
sample_weight_mom)
# Attach grad
_data.attach_grad()
_weight.attach_grad()
# Convert label
_label = self.get_ndarray2(self.gpu, 'mapping_label_%d' % self.rank,
mapping_label)
_label = _label - int(self.rank * self.memory_bank.num_sample)
_fc7, _one_hot = self.fc7_model.forward(_data,
_weight,
mapping_label=_label,
depth=this_rank_classes)
# Sync max
max_fc7 = nd.max(_fc7, axis=1, keepdims=True)
max_fc7 = nd.reshape(max_fc7, -1)
total_max_fc7 = self.get_ndarray(context=self.gpu,
name='total_max_fc7',
shape=(max_fc7.shape[0], self.size),
dtype='float32')
total_max_fc7[:] = 0
total_max_fc7[:, self.rank] = max_fc7
hvd.allreduce_(total_max_fc7, average=False)
global_max_fc7 = self.get_ndarray(context=self.gpu,
name='global_max_fc7',
shape=(max_fc7.shape[0], 1),
dtype='float32')
nd.max(total_max_fc7, axis=1, keepdims=True, out=global_max_fc7)
# Calculate exp(logits)
_fc7_grad = nd.broadcast_sub(_fc7, global_max_fc7)
_fc7_grad = nd.exp(_fc7_grad)
# Calculate sum
sum_fc7 = nd.sum(_fc7_grad, axis=1, keepdims=True)
global_sum_fc7 = hvd.allreduce(sum_fc7, average=False)
# Calculate grad
_fc7_grad = nd.broadcast_div(_fc7_grad, global_sum_fc7)
# Calculate loss
tmp = _fc7_grad * _one_hot
tmp = nd.sum(tmp, axis=1, keepdims=True)
tmp = self.get_ndarray2(self.gpu, 'ctx_loss', tmp)
tmp = hvd.allreduce(tmp, average=False)
global_loss = -nd.mean(nd.log(tmp + 1e-30))
_fc7_grad = _fc7_grad - _one_hot
# Backward
_fc7.backward(out_grad=_fc7_grad)
# Update center
_weight_grad = _weight.grad
self.memory_optimizer.update(weight=_weight,
grad=_weight_grad,
state=_weight_mom,
learning_rate=self.memory_lr)
if self.memory_bank.gpu:
self.memory_bank.set(index=local_index,
updated_weight=_weight,
updated_weight_mom=_weight_mom)
else:
self.memory_bank.set(index=local_index,
updated_weight=self.get_ndarray2(
mx.cpu(), "cpu_weight_%d" % self.rank,
_weight),
updated_weight_mom=self.get_ndarray2(
mx.cpu(), "cpu_weight_mom_%d" % self.rank,
_weight_mom))
return _data.grad, global_loss
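# A minimal usage sketch (hypothetical variable names, shown only to illustrate the expected
# call pattern; the backbone symbol, margin-loss head, memory bank, memory optimizer and data
# iterator are all built by the surrounding training scripts):
#
#   hvd.init()
#   module = SampleDistributeModule(symbol=backbone_symbol,
#                                   fc7_model=margin_loss_head,
#                                   memory_bank=memory_bank,
#                                   memory_optimizer=memory_sgd)
#   module.fit(train_data=train_iter,
#              optimizer_params={'learning_rate': 0.1, 'momentum': 0.9, 'wd': 5e-4},
#              batch_end_callback=speed_callback,
#              initializer=mx.init.Normal(0.02))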