You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

config.py 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import ast
  3. import copy
  4. import os
  5. import os.path as osp
  6. import platform
  7. import shutil
  8. import sys
  9. import tempfile
  10. import types
  11. import uuid
  12. from importlib import import_module
  13. from pathlib import Path
  14. from typing import Dict
  15. import addict
  16. from yapf.yapflib.yapf_api import FormatCode
  17. from maas_lib.utils.logger import get_logger
  18. from maas_lib.utils.pymod import (import_modules, import_modules_from_file,
  19. validate_py_syntax)
  20. if platform.system() == 'Windows':
  21. import regex as re # type: ignore
  22. else:
  23. import re # type: ignore
# Module-level logger obtained from the maas_lib logging helper.
logger = get_logger()
# Special config keys. They are not referenced in the part of the file
# reviewed here; presumably consumed by config-merging code further down
# (TODO confirm).
BASE_KEY = '_base_'
DELETE_KEY = '_delete_'
DEPRECATION_KEY = '_deprecation_'
# Top-level keys that Config.__init__ rejects in a config dict; Config exposes
# read-only properties with these same names.
RESERVED_KEYS = ['filename', 'text', 'pretty_text']
  29. class ConfigDict(addict.Dict):
  30. """ Dict which support get value through getattr
  31. Examples:
  32. >>> cdict = ConfigDict({'a':1232})
  33. >>> print(cdict.a)
  34. 1232
  35. """
  36. def __missing__(self, name):
  37. raise KeyError(name)
  38. def __getattr__(self, name):
  39. try:
  40. value = super(ConfigDict, self).__getattr__(name)
  41. except KeyError:
  42. ex = AttributeError(f"'{self.__class__.__name__}' object has no "
  43. f"attribute '{name}'")
  44. except Exception as e:
  45. ex = e
  46. else:
  47. return value
  48. raise ex
class Config:
    """A facility for config and config files.

    It supports common file formats as configs: python/json/yaml. The
    interface mirrors a dict (item access, iteration, ``len``) and also allows
    config values to be read and written as attributes.

    Example:
        >>> cfg = Config(dict(a=1, b=dict(c=[1,2,3], d='dd')))
        >>> cfg.a
        1
        >>> cfg.b
        {'c': [1, 2, 3], 'd': 'dd'}
        >>> cfg.b.d
        'dd'
        >>> cfg = Config.from_file('configs/examples/config.json')
        >>> cfg.filename
        'configs/examples/config.json'
        >>> cfg.b
        {'c': [1, 2, 3], 'd': 'dd'}
        >>> cfg = Config.from_file('configs/examples/config.py')
        >>> cfg.filename
        'configs/examples/config.py'
        >>> cfg = Config.from_file('configs/examples/config.yaml')
        >>> cfg.filename
        'configs/examples/config.yaml'
    """
  74. @staticmethod
  75. def _file2dict(filename):
  76. filename = osp.abspath(osp.expanduser(filename))
  77. if not osp.exists(filename):
  78. raise ValueError(f'File does not exists {filename}')
  79. fileExtname = osp.splitext(filename)[1]
  80. if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
  81. raise IOError('Only py/yml/yaml/json type are supported now!')
  82. with tempfile.TemporaryDirectory() as tmp_cfg_dir:
  83. tmp_cfg_file = tempfile.NamedTemporaryFile(
  84. dir=tmp_cfg_dir, suffix=fileExtname)
  85. if platform.system() == 'Windows':
  86. tmp_cfg_file.close()
  87. tmp_cfg_name = osp.basename(tmp_cfg_file.name)
  88. shutil.copyfile(filename, tmp_cfg_file.name)
  89. if filename.endswith('.py'):
  90. module_nanme, mod = import_modules_from_file(
  91. osp.join(tmp_cfg_dir, tmp_cfg_name))
  92. cfg_dict = {}
  93. for name, value in mod.__dict__.items():
  94. if not name.startswith('__') and \
  95. not isinstance(value, types.ModuleType) and \
  96. not isinstance(value, types.FunctionType):
  97. cfg_dict[name] = value
  98. # delete imported module
  99. del sys.modules[module_nanme]
  100. elif filename.endswith(('.yml', '.yaml', '.json')):
  101. from maas_lib.fileio import load
  102. cfg_dict = load(tmp_cfg_file.name)
  103. # close temp file
  104. tmp_cfg_file.close()
  105. cfg_text = filename + '\n'
  106. with open(filename, 'r', encoding='utf-8') as f:
  107. # Setting encoding explicitly to resolve coding issue on windows
  108. cfg_text += f.read()
  109. return cfg_dict, cfg_text
  110. @staticmethod
  111. def from_file(filename):
  112. if isinstance(filename, Path):
  113. filename = str(filename)
  114. cfg_dict, cfg_text = Config._file2dict(filename)
  115. return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
  116. @staticmethod
  117. def from_string(cfg_str, file_format):
  118. """Generate config from config str.
  119. Args:
  120. cfg_str (str): Config str.
  121. file_format (str): Config file format corresponding to the
  122. config str. Only py/yml/yaml/json type are supported now!
  123. Returns:
  124. :obj:`Config`: Config obj.
  125. """
  126. if file_format not in ['.py', '.json', '.yaml', '.yml']:
  127. raise IOError('Only py/yml/yaml/json type are supported now!')
  128. if file_format != '.py' and 'dict(' in cfg_str:
  129. # check if users specify a wrong suffix for python
  130. logger.warning(
  131. 'Please check "file_format", the file format may be .py')
  132. with tempfile.NamedTemporaryFile(
  133. 'w', encoding='utf-8', suffix=file_format,
  134. delete=False) as temp_file:
  135. temp_file.write(cfg_str)
  136. # on windows, previous implementation cause error
  137. # see PR 1077 for details
  138. cfg = Config.from_file(temp_file.name)
  139. os.remove(temp_file.name)
  140. return cfg
  141. def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
  142. if cfg_dict is None:
  143. cfg_dict = dict()
  144. elif not isinstance(cfg_dict, dict):
  145. raise TypeError('cfg_dict must be a dict, but '
  146. f'got {type(cfg_dict)}')
  147. for key in cfg_dict:
  148. if key in RESERVED_KEYS:
  149. raise KeyError(f'{key} is reserved for config file')
  150. if isinstance(filename, Path):
  151. filename = str(filename)
  152. super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
  153. super(Config, self).__setattr__('_filename', filename)
  154. if cfg_text:
  155. text = cfg_text
  156. elif filename:
  157. with open(filename, 'r') as f:
  158. text = f.read()
  159. else:
  160. text = ''
  161. super(Config, self).__setattr__('_text', text)
  162. @property
  163. def filename(self):
  164. return self._filename
  165. @property
  166. def text(self):
  167. return self._text
  168. @property
  169. def pretty_text(self):
  170. indent = 4
  171. def _indent(s_, num_spaces):
  172. s = s_.split('\n')
  173. if len(s) == 1:
  174. return s_
  175. first = s.pop(0)
  176. s = [(num_spaces * ' ') + line for line in s]
  177. s = '\n'.join(s)
  178. s = first + '\n' + s
  179. return s
  180. def _format_basic_types(k, v, use_mapping=False):
  181. if isinstance(v, str):
  182. v_str = f"'{v}'"
  183. else:
  184. v_str = str(v)
  185. if use_mapping:
  186. k_str = f"'{k}'" if isinstance(k, str) else str(k)
  187. attr_str = f'{k_str}: {v_str}'
  188. else:
  189. attr_str = f'{str(k)}={v_str}'
  190. attr_str = _indent(attr_str, indent)
  191. return attr_str
  192. def _format_list(k, v, use_mapping=False):
  193. # check if all items in the list are dict
  194. if all(isinstance(_, dict) for _ in v):
  195. v_str = '[\n'
  196. v_str += '\n'.join(
  197. f'dict({_indent(_format_dict(v_), indent)}),'
  198. for v_ in v).rstrip(',')
  199. if use_mapping:
  200. k_str = f"'{k}'" if isinstance(k, str) else str(k)
  201. attr_str = f'{k_str}: {v_str}'
  202. else:
  203. attr_str = f'{str(k)}={v_str}'
  204. attr_str = _indent(attr_str, indent) + ']'
  205. else:
  206. attr_str = _format_basic_types(k, v, use_mapping)
  207. return attr_str
  208. def _contain_invalid_identifier(dict_str):
  209. contain_invalid_identifier = False
  210. for key_name in dict_str:
  211. contain_invalid_identifier |= \
  212. (not str(key_name).isidentifier())
  213. return contain_invalid_identifier
  214. def _format_dict(input_dict, outest_level=False):
  215. r = ''
  216. s = []
  217. use_mapping = _contain_invalid_identifier(input_dict)
  218. if use_mapping:
  219. r += '{'
  220. for idx, (k, v) in enumerate(input_dict.items()):
  221. is_last = idx >= len(input_dict) - 1
  222. end = '' if outest_level or is_last else ','
  223. if isinstance(v, dict):
  224. v_str = '\n' + _format_dict(v)
  225. if use_mapping:
  226. k_str = f"'{k}'" if isinstance(k, str) else str(k)
  227. attr_str = f'{k_str}: dict({v_str}'
  228. else:
  229. attr_str = f'{str(k)}=dict({v_str}'
  230. attr_str = _indent(attr_str, indent) + ')' + end
  231. elif isinstance(v, list):
  232. attr_str = _format_list(k, v, use_mapping) + end
  233. else:
  234. attr_str = _format_basic_types(k, v, use_mapping) + end
  235. s.append(attr_str)
  236. r += '\n'.join(s)
  237. if use_mapping:
  238. r += '}'
  239. return r
  240. cfg_dict = self._cfg_dict.to_dict()
  241. text = _format_dict(cfg_dict, outest_level=True)
  242. # copied from setup.cfg
  243. yapf_style = dict(
  244. based_on_style='pep8',
  245. blank_line_before_nested_class_or_def=True,
  246. split_before_expression_after_opening_paren=True)
  247. text, _ = FormatCode(text, style_config=yapf_style, verify=True)
  248. return text
  249. def __repr__(self):
  250. return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
  251. def __len__(self):
  252. return len(self._cfg_dict)
  253. def __getattr__(self, name):
  254. return getattr(self._cfg_dict, name)
  255. def __getitem__(self, name):
  256. return self._cfg_dict.__getitem__(name)
  257. def __setattr__(self, name, value):
  258. if isinstance(value, dict):
  259. value = ConfigDict(value)
  260. self._cfg_dict.__setattr__(name, value)
  261. def __setitem__(self, name, value):
  262. if isinstance(value, dict):
  263. value = ConfigDict(value)
  264. self._cfg_dict.__setitem__(name, value)
  265. def __iter__(self):
  266. return iter(self._cfg_dict)
  267. def __getstate__(self):
  268. return (self._cfg_dict, self._filename, self._text)
  269. def __copy__(self):
  270. cls = self.__class__
  271. other = cls.__new__(cls)
  272. other.__dict__.update(self.__dict__)
  273. return other
  274. def __deepcopy__(self, memo):
  275. cls = self.__class__
  276. other = cls.__new__(cls)
  277. memo[id(self)] = other
  278. for key, value in self.__dict__.items():
  279. super(Config, other).__setattr__(key, copy.deepcopy(value, memo))
  280. return other
  281. def __setstate__(self, state):
  282. _cfg_dict, _filename, _text = state
  283. super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
  284. super(Config, self).__setattr__('_filename', _filename)
  285. super(Config, self).__setattr__('_text', _text)
  286. def dump(self, file: str = None):
  287. """Dumps config into a file or returns a string representation of the
  288. config.
  289. If a file argument is given, saves the config to that file using the
  290. format defined by the file argument extension.
  291. Otherwise, returns a string representing the config. The formatting of
  292. this returned string is defined by the extension of `self.filename`. If
  293. `self.filename` is not defined, returns a string representation of a
  294. dict (lowercased and using ' for strings).
  295. Examples:
  296. >>> cfg_dict = dict(item1=[1, 2], item2=dict(a=0),
  297. ... item3=True, item4='test')
  298. >>> cfg = Config(cfg_dict=cfg_dict)
  299. >>> dump_file = "a.py"
  300. >>> cfg.dump(dump_file)
  301. Args:
  302. file (str, optional): Path of the output file where the config
  303. will be dumped. Defaults to None.
  304. """
  305. from maas_lib.fileio import dump
  306. cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
  307. if file is None:
  308. if self.filename is None or self.filename.endswith('.py'):
  309. return self.pretty_text
  310. else:
  311. file_format = self.filename.split('.')[-1]
  312. return dump(cfg_dict, file_format=file_format)
  313. elif file.endswith('.py'):
  314. with open(file, 'w', encoding='utf-8') as f:
  315. f.write(self.pretty_text)
  316. else:
  317. file_format = file.split('.')[-1]
  318. return dump(cfg_dict, file=file, file_format=file_format)
  319. def merge_from_dict(self, options, allow_list_keys=True):
  320. """Merge list into cfg_dict.
  321. Merge the dict parsed by MultipleKVAction into this cfg.
  322. Examples:
  323. >>> options = {'model.backbone.depth': 50,
  324. ... 'model.backbone.with_cp':True}
  325. >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
  326. >>> cfg.merge_from_dict(options)
  327. >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
  328. >>> assert cfg_dict == dict(
  329. ... model=dict(backbone=dict(depth=50, with_cp=True)))
  330. >>> # Merge list element
  331. >>> cfg = Config(dict(pipeline=[
  332. ... dict(type='Resize'), dict(type='RandomDistortion')]))
  333. >>> options = dict(pipeline={'0': dict(type='MyResize')})
  334. >>> cfg.merge_from_dict(options, allow_list_keys=True)
  335. >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
  336. >>> assert cfg_dict == dict(pipeline=[
  337. ... dict(type='MyResize'), dict(type='RandomDistortion')])
  338. Args:
  339. options (dict): dict of configs to merge from.
  340. allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
  341. are allowed in ``options`` and will replace the element of the
  342. corresponding index in the config if the config is a list.
  343. Default: True.
  344. """
  345. option_cfg_dict = {}
  346. for full_key, v in options.items():
  347. d = option_cfg_dict
  348. key_list = full_key.split('.')
  349. for subkey in key_list[:-1]:
  350. d.setdefault(subkey, ConfigDict())
  351. d = d[subkey]
  352. subkey = key_list[-1]
  353. d[subkey] = v
  354. cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
  355. super(Config, self).__setattr__(
  356. '_cfg_dict',
  357. Config._merge_a_into_b(
  358. option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
  359. def to_dict(self) -> Dict:
  360. """ Convert Config object to python dict
  361. """
  362. return self._cfg_dict.to_dict()
  363. def to_args(self, parse_fn, use_hyphen=True):
  364. """ Convert config obj to args using parse_fn
  365. Args:
  366. parse_fn: a function object, which takes args as input,
  367. such as ['--foo', 'FOO'] and return parsed args, an
  368. example is given as follows
  369. including literal blocks::
  370. def parse_fn(args):
  371. parser = argparse.ArgumentParser(prog='PROG')
  372. parser.add_argument('-x')
  373. parser.add_argument('--foo')
  374. return parser.parse_args(args)
  375. use_hyphen (bool, optional): if set true, hyphen in keyname
  376. will be converted to underscore
  377. Return:
  378. args: arg object parsed by argparse.ArgumentParser
  379. """
  380. args = []
  381. for k, v in self._cfg_dict.items():
  382. arg_name = f'--{k}'
  383. if use_hyphen:
  384. arg_name = arg_name.replace('_', '-')
  385. if isinstance(v, bool) and v:
  386. args.append(arg_name)
  387. elif isinstance(v, (int, str, float)):
  388. args.append(arg_name)
  389. args.append(str(v))
  390. elif isinstance(v, list):
  391. args.append(arg_name)
  392. assert isinstance(v, (int, str, float, bool)), 'Element type in list ' \
  393. f'is expected to be either int,str,float, but got type {v[0]}'
  394. args.append(str(v))
  395. else:
  396. raise ValueError(
  397. 'type in config file which supported to be '
  398. 'converted to args should be either bool, '
  399. f'int, str, float or list of them but got type {v}')
  400. return parse_fn(args)

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展