mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-11 18:22:41 +00:00
118 lines
3.7 KiB
Python
118 lines
3.7 KiB
Python
import os
|
|
|
|
import numpy as np
|
|
from mxnet import nd
|
|
import mxnet as mx
|
|
|
|
from memory_samplers import WeightIndexSampler
|
|
|
|
|
|
class MemoryBank(object):
|
|
def __init__(self,
|
|
num_sample,
|
|
num_local,
|
|
rank,
|
|
local_rank,
|
|
embedding_size,
|
|
prefix,
|
|
gpu=True):
|
|
"""
|
|
Parameters
|
|
----------
|
|
num_sample: int
|
|
The number of sampled class center.
|
|
num_local: int
|
|
The number of class center storage in this rank(CPU/GPU).
|
|
rank: int
|
|
Unique process(GPU) ID from 0 to size - 1.
|
|
local_rank: int
|
|
Unique process(GPU) ID within the server from 0 to 7.
|
|
embedding_size: int
|
|
The feature dimension.
|
|
prefix_dir: str
|
|
Path prefix of model dir.
|
|
gpu: bool
|
|
If True, class center and class center mom will storage in GPU.
|
|
"""
|
|
self.num_sample = num_sample
|
|
self.num_local = num_local
|
|
self.rank = rank
|
|
self.embedding_size = embedding_size
|
|
self.gpu = gpu
|
|
self.prefix = prefix
|
|
|
|
if self.gpu:
|
|
context = mx.gpu(local_rank)
|
|
else:
|
|
context = mx.cpu()
|
|
|
|
# In order to apply update, weight and momentum should be storage.
|
|
self.weight = nd.random_normal(loc=0,
|
|
scale=0.01,
|
|
shape=(self.num_local,
|
|
self.embedding_size),
|
|
ctx=context)
|
|
self.weight_mom = nd.zeros_like(self.weight)
|
|
|
|
# Sampler object
|
|
self.weight_index_sampler = WeightIndexSampler(num_sample, num_local,
|
|
rank)
|
|
|
|
def sample(self, global_label):
|
|
"""
|
|
Parameters
|
|
----------
|
|
global_label: NDArray
|
|
Global label (after gathers label from all rank)
|
|
Returns
|
|
-------
|
|
index: ndarray(numpy)
|
|
Local index for memory bank to sample, start from 0 to num_local, length is num_sample.
|
|
global_label: ndarray(numpy)
|
|
Global label after sort and unique.
|
|
"""
|
|
assert isinstance(global_label, nd.NDArray)
|
|
global_label = global_label.asnumpy()
|
|
global_label = np.unique(global_label)
|
|
global_label.sort()
|
|
index = self.weight_index_sampler(global_label)
|
|
index.sort()
|
|
return index, global_label
|
|
|
|
def get(self, index):
|
|
"""
|
|
Get sampled class centers and their momentum.
|
|
|
|
Parameters
|
|
----------
|
|
index: NDArray
|
|
Local index for memory bank to sample, start from 0 to num_local.
|
|
"""
|
|
return self.weight[index], self.weight_mom[index]
|
|
|
|
def set(self, index, updated_weight, updated_weight_mom=None):
|
|
"""
|
|
Update sampled class to memory bank, make the class center stored
|
|
in the memory bank the latest.
|
|
|
|
Parameters
|
|
----------
|
|
index: NDArray
|
|
Local index for memory bank to sample, start from 0 to num_local.
|
|
updated_weight: NDArray
|
|
Class center which has been applied gradients.
|
|
updated_weight_mom: NDArray
|
|
Class center momentum which has been moved average.
|
|
"""
|
|
|
|
self.weight[index] = updated_weight
|
|
self.weight_mom[index] = updated_weight_mom
|
|
|
|
def save(self):
|
|
nd.save(fname=os.path.join(self.prefix,
|
|
"%d_centers.param" % self.rank),
|
|
data=self.weight)
|
|
nd.save(fname=os.path.join(self.prefix,
|
|
"%d_centers_mom.param" % self.rank),
|
|
data=self.weight_mom)
|