Files
insightface/recognition/arcface_oneflow/val.py
2021-11-04 19:35:06 +08:00

45 lines
1.3 KiB
Python

import backbones
import oneflow as flow
from utils.utils_callbacks import CallBackVerification
from backbones import get_model
from graph import TrainGraph, EvalGraph
import logging
import argparse
from utils.utils_config import get_config
from function import EvalGraph
def main(args):
cfg = get_config(args.config)
logging.basicConfig(level=logging.NOTSET)
logging.info(args.model_path)
backbone = get_model(cfg.network, dropout=0.0, num_features=cfg.embedding_size).to(
"cuda"
)
val_callback = CallBackVerification(1, 0, cfg.val_targets, cfg.ofrecord_path)
state_dict = flow.load(args.model_path)
new_parameters = dict()
for key, value in state_dict.items():
if "num_batches_tracked" not in key:
if key == "fc.weight":
continue
new_key = key.replace("backbone.", "")
new_parameters[new_key] = value
backbone.load_state_dict(new_parameters)
infer_graph = EvalGraph(backbone)
val_callback(1000, backbone, infer_graph)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="OneFlow ArcFace val")
parser.add_argument("config", type=str, help="py config file")
parser.add_argument("--model_path", type=str, help="model path")
main(parser.parse_args())