mirror of
https://gitcode.com/gh_mirrors/eas/EasyFace.git
synced 2026-05-15 03:38:04 +00:00
63 lines
1.9 KiB
Python
63 lines
1.9 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
import argparse
|
|
import json
|
|
|
|
|
|
def str2bool(v):
|
|
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
|
return True
|
|
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
|
return False
|
|
else:
|
|
raise argparse.ArgumentTypeError('Unsupported value encountered.')
|
|
|
|
|
|
class HParams(dict):
|
|
""" Hyper-parameters class
|
|
|
|
Store hyper-parameters in training / infer / ... scripts.
|
|
"""
|
|
def __getattr__(self, name):
|
|
if name in self.keys():
|
|
return self[name]
|
|
for v in self.values():
|
|
if isinstance(v, HParams):
|
|
if name in v:
|
|
return v[name]
|
|
raise AttributeError(f"'HParams' object has no attribute '{name}'")
|
|
|
|
def __setattr__(self, name, value):
|
|
self[name] = value
|
|
|
|
def save(self, filename):
|
|
with open(filename, 'w', encoding='utf-8') as fp:
|
|
json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False)
|
|
|
|
def load(self, filename):
|
|
with open(filename, 'r', encoding='utf-8') as fp:
|
|
params_dict = json.load(fp)
|
|
for k, v in params_dict.items():
|
|
if isinstance(v, dict):
|
|
self[k].update(HParams(v))
|
|
else:
|
|
self[k] = v
|
|
|
|
|
|
def parse_args(parser):
|
|
""" Parse hyper-parameters from cmdline. """
|
|
parsed = parser.parse_args()
|
|
args = HParams()
|
|
optional_args = parser._action_groups[1]
|
|
for action in optional_args._group_actions[1:]:
|
|
arg_name = action.dest
|
|
args[arg_name] = getattr(parsed, arg_name)
|
|
for group in parser._action_groups[2:]:
|
|
group_args = HParams()
|
|
for action in group._group_actions:
|
|
arg_name = action.dest
|
|
group_args[arg_name] = getattr(parsed, arg_name)
|
|
if len(group_args) > 0:
|
|
args[group.title] = group_args
|
|
return args
|