Files
EasyFace/modelscope/utils/nlp/space/args.py
2023-03-02 11:17:26 +08:00

63 lines
1.9 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import json
def str2bool(v):
    """Interpret a command-line style value as a boolean.

    Suitable as an ``argparse`` ``type=`` callable.

    Args:
        v: Value to interpret. Strings are compared case-insensitively
            against ``yes/true/t/y/1`` (true) and ``no/false/f/n/0``
            (false). A value that is already a ``bool`` is returned as is.

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: If the string matches neither set.
    """
    # A real bool has no ``.lower()``; pass it through instead of crashing
    # when the function is called programmatically.
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')
class HParams(dict):
    """ Hyper-parameters class

    Store hyper-parameters in training / infer / ... scripts.

    A ``dict`` subclass whose entries can also be read and written with
    attribute syntax. Attribute reads additionally search one level down
    into values that are themselves ``HParams`` instances.
    """

    def __getattr__(self, name):
        """Resolve ``name`` as a key of this mapping or of a nested group.

        Args:
            name: Attribute name being looked up.

        Returns:
            ``self[name]`` if present, otherwise ``group[name]`` from the
            first ``HParams`` value (in insertion order) that contains it.

        Raises:
            AttributeError: If no top-level key or direct child group
                provides ``name``.
        """
        if name in self:
            return self[name]
        # Only direct children that are HParams are searched.
        for v in self.values():
            if isinstance(v, HParams) and name in v:
                return v[name]
        raise AttributeError(f"'HParams' object has no attribute '{name}'")

    def __setattr__(self, name, value):
        """Store ``value`` under key ``name`` (attribute writes become items)."""
        self[name] = value

    def save(self, filename):
        """Write the parameters to ``filename`` as UTF-8 encoded JSON.

        Args:
            filename: Path of the file to create or overwrite.
        """
        with open(filename, 'w', encoding='utf-8') as fp:
            json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False)

    def load(self, filename):
        """Read parameters from a JSON file and merge them into ``self``.

        Scalar entries overwrite existing keys. A dict entry is merged into
        an existing dict-valued group of the same name; if no such group
        exists it is created as a new ``HParams``.

        Args:
            filename: Path of a UTF-8 JSON file holding an object.
        """
        with open(filename, 'r', encoding='utf-8') as fp:
            params_dict = json.load(fp)
        for k, v in params_dict.items():
            if isinstance(v, dict):
                current = self.get(k)
                if isinstance(current, dict):
                    current.update(HParams(v))
                else:
                    # Group not present yet (or not a mapping): create it
                    # rather than failing on ``self[k]``.
                    self[k] = HParams(v)
            else:
                self[k] = v
def parse_args(parser):
    """ Parse hyper-parameters from cmdline.

    Runs ``parser.parse_args()`` and repackages the result as ``HParams``:
    arguments of the parser's built-in optional group become top-level
    entries, and each additional argument group becomes a nested
    ``HParams`` stored under the group's title. Empty groups are omitted.
    Positional arguments (action group 0) are not copied.

    Args:
        parser: A configured ``argparse.ArgumentParser``.

    Returns:
        An ``HParams`` holding the parsed values.
    """
    parsed = parser.parse_args()
    args = HParams()

    # NOTE: relies on argparse internals; group 0 is the positional group,
    # group 1 the built-in optional group, later ones are user-defined.
    optional_args = parser._action_groups[1]
    for action in optional_args._group_actions:
        arg_name = action.dest
        # Actions such as ``--help``/``--version`` (default SUPPRESS) leave
        # no attribute on the namespace. Skip by that property instead of by
        # position so parsers built with ``add_help=False`` keep their first
        # option.
        if not hasattr(parsed, arg_name):
            continue
        args[arg_name] = getattr(parsed, arg_name)

    for group in parser._action_groups[2:]:
        group_args = HParams()
        for action in group._group_actions:
            arg_name = action.dest
            if not hasattr(parsed, arg_name):
                continue
            group_args[arg_name] = getattr(parsed, arg_name)
        if group_args:
            args[group.title] = group_args
    return args