Files
insightface/recognition/arcface_oneflow/train.py
2021-10-15 11:30:29 +08:00

94 lines
3.4 KiB
Python

import argparse
import logging
import os
import oneflow as flow
import oneflow.nn as nn
import sys
import math
import numpy as np
import pickle
import time
from backbones import get_model
from utils.utils_callbacks import CallBackVerification, CallBackLogging
from utils.utils_config import get_config
from utils.utils_logging import AverageMeter, init_logging
from utils.ofrecord_data_utils import load_train_dataset, load_synthetic
from function import make_train_func, Validator
def _derive_schedule(cfg):
    """Fill in the batch-size, step-count and LR-schedule fields of ``cfg``.

    Reads ``batch_size``, ``device_num_per_node``, ``num_nodes``,
    ``num_image``, ``num_epoch`` and ``decay_epoch``; writes
    ``total_batch_size``, ``steps_per_epoch``, ``total_step``, ``lr_steps``
    and ``lr_scales`` back onto ``cfg``.
    """
    cfg.total_batch_size = cfg.batch_size * cfg.device_num_per_node * cfg.num_nodes
    # A trailing partial batch still counts as a step.
    cfg.steps_per_epoch = math.ceil(cfg.num_image / cfg.total_batch_size)
    cfg.total_step = cfg.num_epoch * cfg.steps_per_epoch
    # Global step at which each LR decay boundary sits.
    cfg.lr_steps = [epoch * cfg.steps_per_epoch for epoch in cfg.decay_epoch]
    # Cumulative decay factors 0.1, 0.01, 0.001, ... -- exactly one per
    # boundary. (A fixed 4-entry table would yield fewer scales than
    # boundaries when more than four decay epochs are configured.)
    # 1 / 10**k is correctly rounded, so it equals the literal 0.1, 0.01, ...
    cfg.lr_scales = [1 / 10 ** (i + 1) for i in range(len(cfg.lr_steps))]


def _init_distributed_env(cfg):
    """Register the participating machines with OneFlow for multi-node runs.

    Does nothing when ``cfg.num_nodes`` is 1 or less.

    Raises:
        ValueError: if ``cfg.node_ips`` lists fewer addresses than
            ``cfg.num_nodes``.
    """
    if cfg.num_nodes <= 1:
        return
    if cfg.num_nodes > len(cfg.node_ips):
        raise ValueError(
            "The number of nodes should not be greater than length of node_ips list.")
    flow.env.ctrl_port(12138)
    # node_ips may list more hosts than requested; only the first num_nodes
    # take part in this run.
    flow.env.machine([{"addr": ip} for ip in cfg.node_ips[:cfg.num_nodes]])


def main(args):
    """Train the model described by a config file with OneFlow.

    Args:
        args: parsed command-line namespace. ``args.config`` is passed to
            ``get_config``; ``args.device_num_per_node`` overrides the
            config's per-node device count.

    Raises:
        ValueError: for a multi-node config whose ``node_ips`` list is
            shorter than ``num_nodes``.
    """
    cfg = get_config(args.config)
    cfg.device_num_per_node = args.device_num_per_node
    _derive_schedule(cfg)
    cfg.output = os.path.join("work_dir", cfg.output, cfg.loss)
    # NOTE(review): this is the node count, not the total device count;
    # confirm which one CallBackLogging expects as world_size.
    world_size = cfg.num_nodes

    os.makedirs(cfg.output, exist_ok=True)
    log_root = logging.getLogger()
    init_logging(log_root, cfg.output)

    flow.config.gpu_device_num(cfg.device_num_per_node)
    logging.info("gpu num: %d" % cfg.device_num_per_node)
    _init_distributed_env(cfg)
    flow.env.log_dir(cfg.output)

    # Dump the resolved config, keys left-aligned in a 35-char column.
    for key, value in cfg.items():
        logging.info(": " + key.ljust(35) + str(value))

    train_func = make_train_func(cfg)
    val_infer = Validator(cfg)
    callback_verification = CallBackVerification(
        3000, cfg.val_targets, cfg.eval_ofrecord_path)
    callback_logging = CallBackLogging(
        50, cfg.total_step, cfg.total_batch_size, world_size, None)

    if cfg.resume:
        if os.path.exists(cfg.model_load_dir):
            logging.info("Loading model from {}".format(cfg.model_load_dir))
            flow.load_variables(flow.checkpoint.get(cfg.model_load_dir))
        else:
            logging.warning(
                "resume requested but %s does not exist; training from scratch",
                cfg.model_load_dir)

    # NOTE(review): training restarts at epoch 0 even when resuming.
    start_epoch = 0
    global_step = 0
    # In this function lr is only handed to the logging callback; the
    # optimizer presumably builds its schedule from cfg.lr_steps /
    # cfg.lr_scales inside make_train_func -- confirm.
    lr = cfg.lr
    for epoch in range(start_epoch, cfg.num_epoch):
        for _ in range(cfg.steps_per_epoch):
            train_func().async_get(
                callback_logging.metric_cb(global_step, epoch, lr))
            callback_verification(global_step, val_infer.get_symbol_val_fn)
            global_step += 1
        # Boundary k in cfg.lr_steps is decay_epoch[k] * steps_per_epoch,
        # i.e. the value global_step has once decay_epoch[k] epochs have
        # *completed*. Compare against the completed-epoch count (epoch + 1);
        # comparing against `epoch` makes the logged LR lag by a full epoch.
        if epoch + 1 in cfg.decay_epoch:
            lr *= 0.1
            logging.info("lr_steps: %d" % global_step)
            logging.info("lr change to %f" % lr)
        # Snapshot after every epoch.
        path = os.path.join(cfg.output, "snapshot_" + str(epoch))
        flow.checkpoint.save(path)
        logging.info("oneflow Model Saved in '{}'".format(path))
def build_arg_parser():
    """Return the command-line parser for this training script.

    Positional ``config`` is the path to a Python config file;
    ``--device_num_per_node`` (default 1) sets devices per node.
    """
    parser = argparse.ArgumentParser(description='OneFlow ArcFace Training')
    parser.add_argument('config', type=str, help='py config file')
    # main() in this file does not read args.local_rank; presumably kept so
    # distributed launchers that pass it are accepted -- confirm.
    parser.add_argument('--local_rank', type=int, default=0, help='local_rank')
    parser.add_argument('--device_num_per_node', type=int, default=1,
                        help='number of devices (GPUs) to use on each node')
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())