You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hyper_config.py 3.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Hyper config."""
  16. import json
  17. import os
  18. from attrdict import AttrDict
  19. from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME
  20. from mindinsight.optimizer.common.exceptions import HyperConfigError
  # Maximum accepted length of the raw hyper config string read from the
  # environment; longer values are rejected before any JSON decoding happens.
  21. _HYPER_CONFIG_LEN_LIMIT = 100000
  22. class HyperConfig:
  23. """
  24. Hyper config.
  25. Init hyper config:
  26. >>> hyper_config = HyperConfig()
  27. Get suggest params:
  28. >>> param_obj = hyper_config.params
  29. >>> learning_rate = params.learning_rate
  30. Get summary dir:
  31. >>> summary_dir = hyper_config.summary_dir
  32. Record by SummaryCollector:
  33. >>> summary_cb = SummaryCollector(summary_dir)
  34. """
  35. def __init__(self):
  36. self._init_validate_hyper_config()
  37. def _init_validate_hyper_config(self):
  38. """Init and validate hyper config."""
  39. hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME)
  40. if hyper_config is None:
  41. raise HyperConfigError("Hyper config is not in system environment.")
  42. if len(hyper_config) > _HYPER_CONFIG_LEN_LIMIT:
  43. raise HyperConfigError("Hyper config is too long. The length limit is %s, the length of "
  44. "hyper_config is %s." % (_HYPER_CONFIG_LEN_LIMIT, len(hyper_config)))
  45. try:
  46. hyper_config = json.loads(hyper_config)
  47. except TypeError as exc:
  48. raise HyperConfigError("Hyper config type error. detail: %s." % str(exc))
  49. except Exception as exc:
  50. raise HyperConfigError("Hyper config decode error. detail: %s." % str(exc))
  51. self._validate_hyper_config(hyper_config)
  52. self._summary_dir = hyper_config.get('summary_dir')
  53. self._param_obj = AttrDict(hyper_config.get('params'))
  54. def _validate_hyper_config(self, hyper_config):
  55. """Validate hyper config."""
  56. for key in ['summary_dir', 'params']:
  57. if key not in hyper_config:
  58. raise HyperConfigError("%r must exist in hyper_config." % key)
  59. # validate summary_dir
  60. summary_dir = hyper_config.get('summary_dir')
  61. if not isinstance(summary_dir, str):
  62. raise HyperConfigError("The 'summary_dir' should be string.")
  63. hyper_config['summary_dir'] = os.path.realpath(summary_dir)
  64. # validate params
  65. params = hyper_config.get('params')
  66. if not isinstance(params, dict):
  67. raise HyperConfigError("'params' is not a dict.")
  68. for key, value in params.items():
  69. if not isinstance(value, (int, float)):
  70. raise HyperConfigError("The value of %r is not integer or float." % key)
  71. @property
  72. def params(self):
  73. """Get params."""
  74. return self._param_obj
  75. @property
  76. def summary_dir(self):
  77. """Get train summary dir path."""
  78. return self._summary_dir