mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-01 19:20:17 +00:00
94 lines
3.4 KiB
Python
94 lines
3.4 KiB
Python
import argparse
|
|
import logging
|
|
import os
|
|
import oneflow as flow
|
|
import oneflow.nn as nn
|
|
import sys
|
|
import math
|
|
import numpy as np
|
|
import pickle
|
|
import time
|
|
from backbones import get_model
|
|
from utils.utils_callbacks import CallBackVerification, CallBackLogging
|
|
from utils.utils_config import get_config
|
|
from utils.utils_logging import AverageMeter, init_logging
|
|
from utils.ofrecord_data_utils import load_train_dataset, load_synthetic
|
|
from function import make_train_func, Validator
|
|
|
|
|
|
def _derive_schedule(cfg, device_num_per_node):
    """Fill in the batch-size / schedule fields derived from *cfg*, in place.

    Sets ``device_num_per_node``, ``total_batch_size``, ``steps_per_epoch``,
    ``total_step``, ``lr_steps`` and ``lr_scales`` on ``cfg`` and rewrites
    ``cfg.output`` to ``work_dir/<output>/<loss>``.
    """
    cfg.device_num_per_node = device_num_per_node
    cfg.total_batch_size = cfg.batch_size * cfg.device_num_per_node * cfg.num_nodes
    cfg.steps_per_epoch = math.ceil(cfg.num_image / cfg.total_batch_size)
    cfg.total_step = cfg.num_epoch * cfg.steps_per_epoch
    # Global step at which each decay epoch begins.
    cfg.lr_steps = (np.array(cfg.decay_epoch) * cfg.steps_per_epoch).tolist()
    # One cumulative 10x decay per boundary (0.1, 0.01, 0.001, ...). A fixed
    # four-element list used to come up short when more than four decay epochs
    # were configured, leaving lr_scales shorter than lr_steps. Parsing from a
    # decimal literal keeps the first four values exactly 0.1 ... 0.0001.
    cfg.lr_scales = [float("1e-%d" % (i + 1)) for i in range(len(cfg.lr_steps))]
    cfg.output = os.path.join("work_dir", cfg.output, cfg.loss)


def _init_environment(cfg):
    """Create the output dir, initialise logging and configure OneFlow devices/nodes.

    Raises:
        ValueError: if ``cfg.num_nodes > 1`` and ``cfg.node_ips`` lists fewer
            addresses than ``cfg.num_nodes``.
    """
    os.makedirs(cfg.output, exist_ok=True)
    init_logging(logging.getLogger(), cfg.output)

    flow.config.gpu_device_num(cfg.device_num_per_node)
    logging.info("gpu num: %d" % cfg.device_num_per_node)

    if cfg.num_nodes > 1:
        # Explicit check instead of ``assert`` so it still fires under ``python -O``.
        if cfg.num_nodes > len(cfg.node_ips):
            raise ValueError(
                "The number of nodes should not be greater than length of node_ips list.")
        flow.env.ctrl_port(12138)
        # Register exactly num_nodes machines: node_ips may list more addresses
        # than are used, and total_batch_size above is computed from num_nodes.
        flow.env.machine([{"addr": ip} for ip in cfg.node_ips[:cfg.num_nodes]])
    flow.env.log_dir(cfg.output)


def _log_config(cfg):
    """Log every config entry as ``: <key padded to 35 chars><value>``."""
    for key, value in cfg.items():
        logging.info(": " + key.ljust(35) + str(value))


def main(args):
    """Train a face-recognition model with OneFlow as configured by ``args.config``.

    Args:
        args: parsed command-line namespace; ``args.config`` (path of the python
            config file) and ``args.device_num_per_node`` are read here.

    Side effects: creates ``work_dir/<cfg.output>/<cfg.loss>``, configures
    OneFlow devices (plus the multi-node environment when ``cfg.num_nodes > 1``),
    optionally loads a checkpoint, and saves ``snapshot_<epoch>`` after every epoch.
    """
    cfg = get_config(args.config)
    _derive_schedule(cfg, args.device_num_per_node)
    _init_environment(cfg)
    _log_config(cfg)

    # Build the training job and validator only after the OneFlow
    # device/env configuration above is in place.
    train_func = make_train_func(cfg)
    val_infer = Validator(cfg)

    world_size = cfg.num_nodes
    callback_verification = CallBackVerification(
        3000, cfg.val_targets, cfg.eval_ofrecord_path)
    callback_logging = CallBackLogging(
        50, cfg.total_step, cfg.total_batch_size, world_size, None)

    if cfg.resume:
        if os.path.exists(cfg.model_load_dir):
            logging.info("Loading model from {}".format(cfg.model_load_dir))
            flow.load_variables(flow.checkpoint.get(cfg.model_load_dir))
        else:
            # Previously this case fell through silently and trained from scratch.
            logging.warning(
                "resume requested but {} does not exist; training from scratch".format(
                    cfg.model_load_dir))

    start_epoch = 0
    global_step = 0
    lr = cfg.lr  # only used for logging via metric_cb
    for epoch in range(start_epoch, cfg.num_epoch):
        # Decay at the *start* of the epoch: global_step is then
        # epoch * steps_per_epoch, i.e. the matching cfg.lr_steps entry
        # (presumably the boundary make_train_func's scheduler uses — confirm).
        # Decaying after the epoch logged the new rate one epoch late and
        # reported a step that does not appear in cfg.lr_steps.
        if epoch in cfg.decay_epoch:
            lr *= 0.1
            logging.info("lr_steps: %d" % global_step)
            logging.info("lr change to %f" % lr)

        for _ in range(cfg.steps_per_epoch):
            train_func().async_get(callback_logging.metric_cb(global_step, epoch, lr))
            callback_verification(global_step, val_infer.get_symbol_val_fn)
            global_step += 1

        # snapshot after every epoch
        path = os.path.join(cfg.output, "snapshot_" + str(epoch))
        flow.checkpoint.save(path)
        logging.info("oneflow Model Saved in '{}'".format(path))
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse arguments and hand them to main().
    parser = argparse.ArgumentParser(description='OneFlow ArcFace Training')
    parser.add_argument('config', type=str, help='py config file')
    # Parsed but not read by main(); presumably kept so distributed launchers
    # that pass --local_rank do not fail argument parsing — confirm.
    parser.add_argument('--local_rank', type=int, default=0, help='local_rank')
    # The help text was a copy-paste of 'local_rank'; describe the real option.
    parser.add_argument('--device_num_per_node', type=int, default=1,
                        help='number of GPU devices used on each node')
    main(parser.parse_args())
|