|
- """
- Parse argument.
- """
-
- import argparse
-
- import json
-
-
def str2bool(v):
    """Convert a command-line value to a bool; intended as an argparse ``type=``.

    Case-insensitively accepts ``yes/true/t/y/1`` as True and
    ``no/false/f/n/0`` as False; surrounding whitespace is ignored.
    A value that is already a ``bool`` is returned unchanged, so the
    function is also safe to apply to non-string defaults.

    Args:
        v: The raw value, normally a string taken from the command line.

    Returns:
        bool: The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised boolean
            spelling (argparse reports this as a usage error).
    """
    if isinstance(v, bool):
        return v
    # str() so that None / ints produce ArgumentTypeError (or parse, for 0/1)
    # instead of an AttributeError from calling .lower() on a non-string.
    lowered = str(v).strip().lower()
    if lowered in {'yes', 'true', 't', 'y', '1'}:
        return True
    if lowered in {'no', 'false', 'f', 'n', '0'}:
        return False
    # argparse does not echo the rejected value, so include it in the message.
    raise argparse.ArgumentTypeError(f'Unsupported value encountered: {v!r}.')
-
-
class HParams(dict):
    """Hyper-parameters class.

    Store hyper-parameters in training / infer / ... scripts. This is a
    ``dict`` whose keys can also be read, written and deleted as
    attributes (``hp.lr`` is ``hp['lr']``). An attribute read that misses
    at the top level falls back to a search one level deep through nested
    ``HParams`` values, so ``hp['optim']['lr']`` is also reachable as
    ``hp.lr``.
    """

    def __getattr__(self, name):
        """Resolve ``name`` as a top-level key, then inside nested HParams.

        Only invoked when normal attribute lookup fails. Nested groups are
        searched one level deep, in dict iteration order; the first match
        wins.

        Raises:
            AttributeError: If no key matches, so ``hasattr`` and
                ``getattr(hp, name, default)`` behave normally.
        """
        if name in self:
            return self[name]
        for value in self.values():
            if isinstance(value, HParams) and name in value:
                return value[name]
        raise AttributeError(f"'HParams' object has no attribute '{name}'")

    def __setattr__(self, name, value):
        """Store attribute assignments as dictionary items."""
        self[name] = value

    def __delattr__(self, name):
        """Delete the top-level key ``name``, mirroring ``__setattr__``."""
        try:
            del self[name]
        except KeyError:
            raise AttributeError(
                f"'HParams' object has no attribute '{name}'") from None

    def save(self, filename):
        """Write the parameters to ``filename`` as indented UTF-8 JSON."""
        with open(filename, 'w', encoding='utf-8') as fp:
            json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False)

    def load(self, filename):
        """Merge parameters from a JSON file (e.g. one written by ``save``).

        Non-object values overwrite the entry of the same key. A JSON
        object is merged (shallow ``update``) into the existing entry when
        that entry is a dict; if the key is missing or holds a non-dict
        value, the object is stored as a new ``HParams`` group. Only the
        first nesting level is converted to ``HParams``; deeper objects
        remain plain dicts.
        """
        with open(filename, 'r', encoding='utf-8') as fp:
            params_dict = json.load(fp)
        for key, value in params_dict.items():
            if not isinstance(value, dict):
                self[key] = value
                continue
            existing = self.get(key)
            if isinstance(existing, dict):
                existing.update(value)
            else:
                # Previously raised KeyError / AttributeError here, so a fresh
                # HParams could not load a file containing groups.
                self[key] = HParams(value)
-
-
def _collect_group(parsed, group, into):
    """Copy the parsed value of every action in ``group`` into ``into``."""
    for action in group._group_actions:
        # Actions such as -h/--help and --version (and any option whose
        # default is argparse.SUPPRESS and was not given) never appear on
        # the namespace; skip them rather than assuming help is first.
        if hasattr(parsed, action.dest):
            into[action.dest] = getattr(parsed, action.dest)
    return into


def parse_args(parser, argv=None):
    """Parse hyper-parameters from cmdline into an ``HParams``.

    Arguments in the parser's built-in positional and optional groups are
    stored at the top level. Arguments in each group added with
    ``parser.add_argument_group(title)`` are collected into a nested
    ``HParams`` stored under ``title``; groups that yield no values are
    omitted.

    Args:
        parser: A configured ``argparse.ArgumentParser``.
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and notebooks).

    Returns:
        HParams: The parsed values, grouped as described above.
    """
    parsed = parser.parse_args(argv)
    # argparse has no public API for argument groups, so this relies on the
    # private _action_groups list: index 0 is the default positional group,
    # index 1 the default optional group, later entries are user groups.
    groups = parser._action_groups
    args = HParams()
    for group in groups[:2]:
        _collect_group(parsed, group, args)
    for group in groups[2:]:
        group_args = _collect_group(parsed, group, HParams())
        if group_args:
            args[group.title] = group_args
    return args
|