fix path errors

This commit is contained in:
olojuwin
2021-11-06 19:46:22 +08:00
parent 7770ca74f6
commit 2251bead1b
4 changed files with 16 additions and 13 deletions

View File

@@ -5,7 +5,7 @@ from utils.utils_logging import AverageMeter
from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
from backbones import get_model
from graph import TrainGraph, EvalGraph
from losses import CrossEntropyLoss_sbp
from utils.losses import CrossEntropyLoss_sbp
import logging

View File

@@ -30,14 +30,12 @@ class TrainGraph(flow.nn.Graph):
):
super().__init__()
if cfg.use_fp16:
if cfg.fp16:
self.config.enable_amp(True)
self.set_grad_scaler(make_grad_scaler())
elif cfg.scale_grad:
self.set_grad_scaler(make_static_grad_scaler())
self.config.allow_fuse_add_to_output(True)
self.config.allow_fuse_model_update_ops(True)

View File

@@ -47,7 +47,8 @@ class CallBackVerification(object):
self.ver_list[i], backbone, 10, 10, self.is_consistent
)
logging.info(
"[%s][%d]XNorm: %f" % (self.ver_name_list[i], global_step, xnorm)
"[%s][%d]XNorm: %f" % (
self.ver_name_list[i], global_step, xnorm)
)
logging.info(
"[%s][%d]Accuracy-Flip: %1.5f+-%1.5f"
@@ -64,11 +65,13 @@ class CallBackVerification(object):
def init_dataset(self, val_targets, data_dir, image_size):
for name in val_targets:
path = os.path.join(data_dir, name + ".bin")
path = os.path.join(data_dir, "val", name + ".bin")
if os.path.exists(path):
data_set = verification.load_bin_cv(path, image_size)
self.ver_list.append(data_set)
self.ver_name_list.append(name)
if len(self.ver_list) == 0:
logging.info("Val targets is None !")
def __call__(self, num_update, backbone: flow.nn.Module, backbone_graph=None):
@@ -120,8 +123,10 @@ class CallBackLogging(object):
time_total = time_now / ((global_step + 1) / self.total_step)
time_for_end = time_total - time_now
if self.writer is not None:
self.writer.add_scalar("time_for_end", time_for_end, global_step)
self.writer.add_scalar("learning_rate", learning_rate, global_step)
self.writer.add_scalar(
"time_for_end", time_for_end, global_step)
self.writer.add_scalar(
"learning_rate", learning_rate, global_step)
self.writer.add_scalar("loss", loss.avg, global_step)
if fp16:
msg = (
@@ -168,7 +173,8 @@ class CallBackModelCheckpoint(object):
path_module = os.path.join(self.output, "epoch_%d" % (epoch))
if is_consistent:
flow.save(backbone.state_dict(), path_module, consistent_dst_rank=0)
flow.save(backbone.state_dict(),
path_module, consistent_dst_rank=0)
else:
if self.rank == 0:
flow.save(backbone.state_dict(), path_module)

View File

@@ -1,4 +1,3 @@
import backbones
import oneflow as flow
from utils.utils_callbacks import CallBackVerification
from backbones import get_model
@@ -6,7 +5,6 @@ from graph import TrainGraph, EvalGraph
import logging
import argparse
from utils.utils_config import get_config
from function import EvalGraph
def main(args):
@@ -18,7 +16,8 @@ def main(args):
backbone = get_model(cfg.network, dropout=0.0, num_features=cfg.embedding_size).to(
"cuda"
)
val_callback = CallBackVerification(1, 0, cfg.val_targets, cfg.ofrecord_path)
val_callback = CallBackVerification(
1, 0, cfg.val_targets, cfg.ofrecord_path)
state_dict = flow.load(args.model_path)
@@ -32,7 +31,7 @@ def main(args):
backbone.load_state_dict(new_parameters)
infer_graph = EvalGraph(backbone)
infer_graph = EvalGraph(backbone, cfg)
val_callback(1000, backbone, infer_graph)