2021-10-15 11:30:29 +08:00
|
|
|
import argparse
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import oneflow as flow
|
2021-11-04 19:35:06 +08:00
|
|
|
|
|
|
|
|
from function import Trainer
|
|
|
|
|
from utils.utils_logging import init_logging
|
2021-10-15 11:30:29 +08:00
|
|
|
from utils.utils_config import get_config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(args):
    """Build the run configuration, set up logging, and launch training.

    Args:
        args: Parsed command-line namespace. ``args.config`` is handed to
            ``get_config``; ``args.graph`` is copied onto the resulting config
            as ``cfg.graph``.
    """
    cfg = get_config(args.config)
    cfg.graph = args.graph

    # Distributed context of the current process.
    rank = flow.env.get_rank()
    world_size = flow.env.get_world_size()
    placement = flow.env.all_device_placement("cuda")

    # The output directory must exist before logging is initialised into it.
    os.makedirs(cfg.output, exist_ok=True)
    init_logging(logging.getLogger(), rank, cfg.output)

    # Root dir of the checkpoint to load; hard-coded to None in this script.
    load_path = None

    # Dump every config entry, key left-justified to a 25-char column.
    for name, val in cfg.items():
        logging.info(f": {name.ljust(25)}{val!s}")

    runner = Trainer(cfg, placement, load_path, world_size, rank)
    runner()
|
2021-10-15 11:30:29 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line interface for the training script.
    arg_parser = argparse.ArgumentParser(description="OneFlow ArcFace Training")
    arg_parser.add_argument("config", type=str, help="py config file")
    arg_parser.add_argument(
        "--graph",
        action="store_true",
        help="Run model in graph mode,else run model in ddp mode.",
    )
    # NOTE(review): --local_rank is parsed but never read in this file;
    # presumably accepted because a distributed launcher passes it — confirm.
    arg_parser.add_argument("--local_rank", type=int, default=0, help="local_rank")

    cli_args = arg_parser.parse_args()
    main(cli_args)
|