import logging
import os
import sys
import time

import horovod.mxnet as hvd
import mxnet as mx
from mxboard import SummaryWriter
from mxnet import nd

from default import config
from evaluation import verification


class MetricNdarray(object):
    """Running average of an NDArray-valued metric (e.g. the training loss)."""

    def __init__(self):
        self.sum = None
        self.count = 0
        self.reset()

    def reset(self):
        self.sum = None
        self.count = 0

    def update(self, val, n=1):
        assert isinstance(val, mx.nd.NDArray), type(val)
        if self.sum is None:  # lazily initialize the accumulator
            self.sum = mx.nd.zeros_like(val)

        self.sum += val * n
        self.count += n

    def get(self):
        average = self.sum / self.count
        return average.asscalar()
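
# Example usage of MetricNdarray (a minimal sketch, not part of the training
# pipeline; assumes an MXNet context is available):
#
#     metric = MetricNdarray()
#     metric.update(mx.nd.array([2.0]))       # sum = 2.0, count = 1
#     metric.update(mx.nd.array([4.0]), n=1)  # sum = 6.0, count = 2
#     print(metric.get())                     # -> 3.0 (mean of the updates)
#     metric.reset()                          # clear the accumulator

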
class CallBackVertification(object):
    """Run verification on each `<name>.bin` set every `config.verbose` updates."""

    def __init__(self, symbol, model):
        self.verbose = config.verbose
        self.symbol = symbol
        self.highest_acc = 0.0
        self.highest_acc_list = [0.0] * len(config.val_targets)
        self.model = model
        self.ver_list = []
        self.ver_name_list = []
        self.init_dataset(val_targets=config.val_targets,
                          data_dir=os.path.dirname(config.rec),
                          image_size=(config.image_size, config.image_size))

    def ver_test(self, num_update):
        results = []
        for i in range(len(self.ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                self.ver_list[i], self.model, 10, 10, None, None)
            logging.info('[%s][%d]XNorm: %f' %
                         (self.ver_name_list[i], num_update, xnorm))
            logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                         (self.ver_name_list[i], num_update, acc2, std2))
            if acc2 > self.highest_acc_list[i]:
                self.highest_acc_list[i] = acc2
            logging.info(
                '[%s][%d]Accuracy-Highest: %1.5f' %
                (self.ver_name_list[i], num_update, self.highest_acc_list[i]))
            results.append(acc2)
        return results

    def init_dataset(self, val_targets, data_dir, image_size):
        for name in val_targets:
            path = os.path.join(data_dir, name + ".bin")
            if os.path.exists(path):
                data_set = verification.load_bin(path, image_size)
                self.ver_list.append(data_set)
                self.ver_name_list.append(name)

    def __call__(self, param):
        num_update = param.num_update
        # Run verification every `self.verbose` updates (skip update 0).
        if num_update > 0 and num_update % self.verbose == 0:
            self.ver_test(num_update)
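
# CallBackVertification loads every `<name>.bin` listed in config.val_targets
# from the directory containing config.rec, and re-runs verification.test on
# them every `config.verbose` updates. Minimal sketch of standalone use
# (`sym` and `module` are hypothetical stand-ins for the training symbol and
# the model object expected by verification.test):
#
#     ver_cb = CallBackVertification(symbol=sym, model=module)
#     ver_cb.ver_test(num_update=0)  # evaluate all loaded validation sets once

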
class CallBackCenterSave(object):
    """Persist the memory bank every `save_interval` updates."""

    def __init__(self, memory_bank, save_interval=10000):
        self.save_interval = save_interval
        self.memory_bank = memory_bank

    def __call__(self, param):
        if param.num_update % self.save_interval == 0:
            self.memory_bank.save()
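
# CallBackCenterSave only requires that `memory_bank` exposes a save() method.
# Minimal sketch with a hypothetical memory bank object:
#
#     center_cb = CallBackCenterSave(memory_bank=my_memory_bank,
#                                    save_interval=10000)
#     center_cb(param)  # calls my_memory_bank.save() when num_update % 10000 == 0

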
class CallBackModelSave(object):
    """Export the symbol and parameters to checkpoints during training."""

    def __init__(self, symbol, model, prefix, rank):
        self.symbol = symbol
        self.model = model
        self.prefix = prefix
        self.max_step = config.max_update
        self.rank = rank

    def __call__(self, param):
        num_update = param.num_update

        # Save shortly before the final update and every 10000 updates.
        if (num_update in [self.max_step - 10]
                or (num_update % 10000 == 0 and num_update > 0)):

            # export parameters and symbol
            arg, aux = self.model.get_export_params()
            _sym = self.symbol

            # average arg/aux parameters across all Horovod workers
            new_arg, new_aux = {}, {}
            for key, tensor in aux.items():
                new_aux[key] = hvd.allreduce(tensor, average=True)
            for key, tensor in arg.items():
                new_arg[key] = hvd.allreduce(tensor, average=True)

            if self.rank == 0:
                # worker-averaged parameters
                mx.model.save_checkpoint(prefix=self.prefix + "_average",
                                         epoch=0,
                                         symbol=_sym,
                                         arg_params=new_arg,
                                         aux_params=new_aux)
                # this worker's own parameters
                mx.model.save_checkpoint(prefix=self.prefix,
                                         epoch=0,
                                         symbol=_sym,
                                         arg_params=arg,
                                         aux_params=aux)

        # training is over
        if num_update > self.max_step > 0:
            logging.info('Training is over!')
            sys.exit(0)
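
# The checkpoints written above can be read back with MXNet's standard loader
# (a sketch; "./output/model" is a hypothetical value for `prefix`, so the
# averaged checkpoint would live under "./output/model_average"):
#
#     sym, arg_params, aux_params = mx.model.load_checkpoint(
#         "./output/model_average", 0)

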
class MetricCallBack(object):
    def __init__(self, batch_size, rank, size, prefix_dir, frequent):
        self.batch_size = batch_size
        self.rank = rank
        self.size = size
        self.prefix_dir = prefix_dir
        self.frequent = frequent
        self.init = False
        self.tic = 0
        self.last_count = 0
        self.loss_metric_list = MetricNdarray()
        t = time.localtime()

        self.summary_writer = SummaryWriter(
            logdir=os.path.join(
                self.prefix_dir, 'log_tensorboard',
                str(t.tm_mon) + '_' + str(t.tm_mday) + '_' + str(t.tm_hour)),
            verbose=False)
class CallBackLogging(object):
    """Log training speed and the running loss every `config.frequent` updates."""

    def __init__(self, rank, size, prefix_dir):
        self.batch_size = config.batch_size
        self.rank = rank
        self.size = size
        self.prefix_dir = prefix_dir
        self.frequent = config.frequent
        self.init = False
        self.tic = 0
        self.last_count = 0
        self.loss_metric = MetricNdarray()
        t = time.localtime()

        # Only rank 0 writes TensorBoard summaries.
        if self.rank == 0:
            self.summary_writer = SummaryWriter(
                logdir=os.path.join(
                    self.prefix_dir, "log_tensorboard",
                    "%s_%s_%s" % (str(t.tm_mon), str(t.tm_mday), str(t.tm_hour))),
                verbose=False)
        else:
            time.sleep(2)

    def __call__(self, param):
        """Log iteration speed and the running loss."""
        count = param.num_update

        if self.last_count > count:
            self.init = False
        self.last_count = count

        self.loss_metric.update(param.loss[0])

        if self.init:
            if count % self.frequent == 0:
                nd.waitall()
                try:
                    speed = self.frequent * self.batch_size / (time.time() - self.tic)
                    speed_total = speed * self.size
                except ZeroDivisionError:
                    speed = float('inf')
                    speed_total = float('inf')

                # running-average loss since the last report
                loss_scalar = self.loss_metric.get()
                loss_str_format = "[%d][%s]:%.2f " % (param.num_epoch, "loss",
                                                      loss_scalar)
                self.loss_metric.reset()

                if self.rank == 0:
                    self.summary_writer.add_scalar(tag="loss",
                                                   value=loss_scalar,
                                                   global_step=param.num_update)
                    self.summary_writer.add_scalar(tag="speed",
                                                   value=speed,
                                                   global_step=param.num_update)
                    self.summary_writer.flush()
                    logging.info(
                        "Iter:%d Rank:%.2f it/sec Total:%.2f it/sec %s",
                        param.num_update, speed, speed_total, loss_str_format)

                self.tic = time.time()
            else:
                self.init = True
                self.tic = time.time()
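
# How these callbacks are typically driven (a minimal sketch, not the actual
# training script; `BatchParam`, `sym`, `module`, `train_iter`, and
# `train_one_step` are hypothetical stand-ins -- the callbacks only rely on
# the `num_update`, `num_epoch`, and `loss` attributes used above):
#
#     import collections
#     BatchParam = collections.namedtuple(
#         "BatchParam", ["num_update", "num_epoch", "loss"])
#
#     cb_logging = CallBackLogging(rank=hvd.rank(), size=hvd.size(),
#                                  prefix_dir="./output")
#     cb_verify = CallBackVertification(symbol=sym, model=module)
#     cb_save = CallBackModelSave(symbol=sym, model=module,
#                                 prefix="./output/model", rank=hvd.rank())
#
#     for update, batch in enumerate(train_iter, start=1):
#         loss = train_one_step(batch)  # hypothetical training step
#         param = BatchParam(num_update=update, num_epoch=0, loss=[loss])
#         cb_logging(param)
#         cb_verify(param)
#         cb_save(param)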