Mirror of https://github.com/deepinsight/insightface.git (synced 2025-12-30 08:02:27 +00:00)
fix path errors
@@ -5,7 +5,7 @@ from utils.utils_logging import AverageMeter
 from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
 from backbones import get_model
 from graph import TrainGraph, EvalGraph
-from losses import CrossEntropyLoss_sbp
+from utils.losses import CrossEntropyLoss_sbp
 import logging
 
 
@@ -30,14 +30,12 @@ class TrainGraph(flow.nn.Graph):
     ):
         super().__init__()
 
-        if cfg.use_fp16:
+        if cfg.fp16:
             self.config.enable_amp(True)
             self.set_grad_scaler(make_grad_scaler())
         elif cfg.scale_grad:
             self.set_grad_scaler(make_static_grad_scaler())
-
-
 
         self.config.allow_fuse_add_to_output(True)
         self.config.allow_fuse_model_update_ops(True)
 
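Note on the AMP branch above: make_grad_scaler and make_static_grad_scaler are helpers that this hunk only references. A minimal sketch of how such helpers can be written with OneFlow's AMP utilities (the hyper-parameters and the use of flow.amp.GradScaler / flow.amp.StaticGradScaler are assumptions for illustration, not part of this commit):

import oneflow as flow

def make_grad_scaler():
    # Dynamic loss scaling for mixed-precision (cfg.fp16) training;
    # the scaling hyper-parameters below are illustrative defaults.
    return flow.amp.GradScaler(
        init_scale=2 ** 30,
        growth_factor=2.0,
        backoff_factor=0.5,
        growth_interval=2000,
    )

def make_static_grad_scaler():
    # Fixed loss scale for the cfg.scale_grad branch (no dynamic adjustment);
    # scaling by world size is an assumed choice, not taken from the diff.
    return flow.amp.StaticGradScaler(flow.env.get_world_size())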
@@ -47,7 +47,8 @@ class CallBackVerification(object):
                 self.ver_list[i], backbone, 10, 10, self.is_consistent
             )
             logging.info(
-                "[%s][%d]XNorm: %f" % (self.ver_name_list[i], global_step, xnorm)
+                "[%s][%d]XNorm: %f" % (
+                    self.ver_name_list[i], global_step, xnorm)
             )
             logging.info(
                 "[%s][%d]Accuracy-Flip: %1.5f+-%1.5f"
@@ -64,11 +65,13 @@ class CallBackVerification(object):
     def init_dataset(self, val_targets, data_dir, image_size):
         for name in val_targets:
-            path = os.path.join(data_dir, name + ".bin")
+            path = os.path.join(data_dir, "val", name + ".bin")
             if os.path.exists(path):
                 data_set = verification.load_bin_cv(path, image_size)
                 self.ver_list.append(data_set)
                 self.ver_name_list.append(name)
         if len(self.ver_list) == 0:
             logging.info("Val targets is None !")
 
     def __call__(self, num_update, backbone: flow.nn.Module, backbone_graph=None):
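This hunk is the substance of the "fix path errors" commit: verification .bin files are now resolved under a val/ subdirectory of the dataset root rather than the root itself. A small self-contained illustration of the resulting paths (the dataset root and target names are hypothetical examples):

import os

data_dir = "/data/ms1m/ofrecord"                 # hypothetical cfg.ofrecord_path
val_targets = ["lfw", "cfp_fp", "agedb_30"]      # illustrative verification sets

for name in val_targets:
    old_path = os.path.join(data_dir, name + ".bin")          # before the fix
    new_path = os.path.join(data_dir, "val", name + ".bin")   # after the fix
    print(old_path, "->", new_path)
# e.g. /data/ms1m/ofrecord/lfw.bin -> /data/ms1m/ofrecord/val/lfw.bin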
@@ -120,8 +123,10 @@ class CallBackLogging(object):
                 time_total = time_now / ((global_step + 1) / self.total_step)
                 time_for_end = time_total - time_now
                 if self.writer is not None:
-                    self.writer.add_scalar("time_for_end", time_for_end, global_step)
-                    self.writer.add_scalar("learning_rate", learning_rate, global_step)
+                    self.writer.add_scalar(
+                        "time_for_end", time_for_end, global_step)
+                    self.writer.add_scalar(
+                        "learning_rate", learning_rate, global_step)
                     self.writer.add_scalar("loss", loss.avg, global_step)
                 if fp16:
                     msg = (
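The two add_scalar calls above are only re-wrapped; the logged tags and values are unchanged. For reference, a minimal sketch of the writer side (which package provides self.writer is not visible in this diff, so tensorboardX is assumed here; the log directory and values are placeholders):

from tensorboardX import SummaryWriter   # assumed writer implementation

writer = SummaryWriter(log_dir="output/logs")            # hypothetical log directory
global_step = 100
writer.add_scalar("time_for_end", 123.4, global_step)    # tag, value, step
writer.add_scalar("learning_rate", 0.1, global_step)
writer.add_scalar("loss", 0.87, global_step)
writer.close()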
@@ -168,7 +173,8 @@ class CallBackModelCheckpoint(object):
         path_module = os.path.join(self.output, "epoch_%d" % (epoch))
 
         if is_consistent:
-            flow.save(backbone.state_dict(), path_module, consistent_dst_rank=0)
+            flow.save(backbone.state_dict(),
+                      path_module, consistent_dst_rank=0)
         else:
             if self.rank == 0:
                 flow.save(backbone.state_dict(), path_module)
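The flow.save call above is also only re-wrapped. A short sketch of the two checkpoint modes it distinguishes (the consistent_dst_rank keyword comes from the diff; the placeholder module and output path are assumptions):

import os
import oneflow as flow

backbone = flow.nn.Linear(8, 8)                  # placeholder module
os.makedirs("output", exist_ok=True)
path_module = os.path.join("output", "epoch_0")

is_consistent = False   # True when the model holds global (consistent) tensors
if is_consistent:
    # Every rank participates; the checkpoint is gathered onto rank 0.
    flow.save(backbone.state_dict(), path_module, consistent_dst_rank=0)
else:
    # Local tensors: only rank 0 writes, so ranks do not overwrite each other.
    if flow.env.get_rank() == 0:
        flow.save(backbone.state_dict(), path_module)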
@@ -1,4 +1,3 @@
-import backbones
 import oneflow as flow
 from utils.utils_callbacks import CallBackVerification
 from backbones import get_model
@@ -6,7 +5,6 @@ from graph import TrainGraph, EvalGraph
 import logging
 import argparse
 from utils.utils_config import get_config
-from function import EvalGraph
 
 
 def main(args):
@@ -18,7 +16,8 @@ def main(args):
     backbone = get_model(cfg.network, dropout=0.0, num_features=cfg.embedding_size).to(
         "cuda"
     )
-    val_callback = CallBackVerification(1, 0, cfg.val_targets, cfg.ofrecord_path)
+    val_callback = CallBackVerification(
+        1, 0, cfg.val_targets, cfg.ofrecord_path)
 
     state_dict = flow.load(args.model_path)
 
@@ -32,7 +31,7 @@ def main(args):
 
     backbone.load_state_dict(new_parameters)
 
-    infer_graph = EvalGraph(backbone)
+    infer_graph = EvalGraph(backbone, cfg)
     val_callback(1000, backbone, infer_graph)
 
 
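Passing cfg into EvalGraph matches the TrainGraph hunk earlier in this diff, where cfg gates AMP. A minimal sketch of an evaluation graph with this constructor shape (an illustration consistent with the call above, not the repository's actual graph.py):

import oneflow as flow

class EvalGraph(flow.nn.Graph):
    def __init__(self, model, cfg):
        super().__init__()
        # Mirror the training graph: cfg.fp16 enables AMP for inference as well.
        if cfg.fp16:
            self.config.enable_amp(True)
        self.model = model

    def build(self, image):
        # Forward pass only; nn.Graph traces and compiles this on first call.
        return self.model(image)

# Hypothetical usage, matching the evaluation script above:
#   infer_graph = EvalGraph(backbone, cfg)
#   embeddings = infer_graph(images)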