You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

algo_parameter_config.py 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Configuration of parameters for strategy-searching algorithm in auto_parallel"""
  16. import threading
  17. from mindspore._c_expression import CostModelContext
  18. from mindspore._checkparam import args_type_check
  19. __all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
  20. class _AlgoParameterConfig():
  21. """
  22. _AlgoParameterConfig is the configuration of setting parameters used in th algorithm.
  23. Note:
  24. Creating a config through instantiating _AlgoParameterConfig object is not recommended.
  25. Use algo_parameter_config() to get the configuration since _AlgoParameterConfig is singleton.
  26. """
  27. _instance = None
  28. _instance_lock = threading.Lock()
  29. def __init__(self):
  30. self._config_handle = CostModelContext.get_instance()
  31. def check_config_handle(self):
  32. """
  33. Check config handle.
  34. Raises:
  35. ValueError: If the config handle is none.
  36. """
  37. if self._config_handle is None:
  38. raise ValueError("Config handle is none!!!")
  39. def set_simplify_cal(self, simplify_cal):
  40. self.check_config_handle()
  41. self._config_handle.set_simplify_cal(simplify_cal)
  42. def get_simplify_cal(self):
  43. self.check_config_handle()
  44. return self._config_handle.get_simplify_cal()
  45. def set_fully_use_devices(self, not_fully):
  46. self.check_config_handle()
  47. self._config_handle.set_fully_use_devices(not_fully)
  48. def get_fully_use_devices(self):
  49. self.check_config_handle()
  50. return self._config_handle.get_fully_use_devices()
  51. def set_elementwise_op_strategy_follow(self, element_strategy_follow):
  52. self.check_config_handle()
  53. self._config_handle.set_elementwise_op_strategy_follow(element_strategy_follow)
  54. def get_elementwise_op_strategy_follow(self):
  55. self.check_config_handle()
  56. return self._config_handle.get_elementwise_op_strategy_follow()
  57. def set_tensor_slice_align_enable(self, align_enable):
  58. self.check_config_handle()
  59. self._config_handle.set_tensor_slice_align_enable(align_enable)
  60. def get_tensor_slice_align_enable(self):
  61. self.check_config_handle()
  62. return self._config_handle.get_tensor_slice_align_enable()
  63. def set_tensor_slice_align_size(self, align_size):
  64. """
  65. Set tensor slice align size.
  66. Args:
  67. align_size (int): The minimum tensor slice shape.
  68. Raises:
  69. ValueError: If align_size is not in [1, 1024].
  70. """
  71. self.check_config_handle()
  72. if align_size < 1 or align_size > 1024:
  73. raise ValueError('Align_size must be in [1, 1024], but got {}'.format(align_size))
  74. self._config_handle.set_tensor_slice_align_size(align_size)
  75. def get_tensor_slice_align_size(self):
  76. self.check_config_handle()
  77. return self._config_handle.get_tensor_slice_align_size()
  78. def reset_algo_parameters(self):
  79. self.check_config_handle()
  80. self._config_handle.reset_algo_parameters()
  81. _g_algo_parameter_config = None
  82. def _algo_parameter_config():
  83. """
  84. Get the global _g_algo_parameter_config. If it is not created, create a new one.
  85. Returns:
  86. The global _g_algo_parameter_config.
  87. """
  88. global _g_algo_parameter_config
  89. if _g_algo_parameter_config is None:
  90. _g_algo_parameter_config = _AlgoParameterConfig()
  91. return _g_algo_parameter_config
  92. set_algo_parameters_config_func_map = {
  93. "simplify_cal": _algo_parameter_config().set_simplify_cal,
  94. "fully_use_devices": _algo_parameter_config().set_fully_use_devices,
  95. "elementwise_op_strategy_follow": _algo_parameter_config().set_elementwise_op_strategy_follow,
  96. "tensor_slice_align_enable": _algo_parameter_config().set_tensor_slice_align_enable,
  97. "tensor_slice_align_size": _algo_parameter_config().set_tensor_slice_align_size}
  98. get_algo_parameters_config_func_map = {
  99. "simplify_cal": _algo_parameter_config().get_simplify_cal,
  100. "fully_use_devices": _algo_parameter_config().get_fully_use_devices,
  101. "elementwise_op_strategy_follow": _algo_parameter_config().get_elementwise_op_strategy_follow,
  102. "tensor_slice_align_enable": _algo_parameter_config().get_tensor_slice_align_enable,
  103. "tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size}
  104. @args_type_check(simplify_cal=bool, tensor_slice_align_enable=bool, tensor_slice_align_size=int,
  105. fully_use_devices=bool, elementwise_op_strategy_follow=bool)
  106. def set_algo_parameters(**kwargs):
  107. """
  108. Set algo parameter config.
  109. Note:
  110. Attribute name is needed.
  111. Args:
  112. simplify_cal (bool): Whether simplifying calculations in strategy-searching algorithm. Default: True
  113. tensor_slice_align_enable (bool): Whether checking tensor slice shape. Default: False
  114. tensor_slice_align_size (int): The minimum tensor slice shape, the value must be in [1, 1024]. Default: 16
  115. fully_use_devices (bool): Whether generating strategies that fully use all available devices. Default: True
  116. elementwise_op_strategy_follow (bool): Whether the elementwise operator have the same strategies as its
  117. subsequent operators. Default: False
  118. Raises:
  119. ValueError: If context keyword is not recognized.
  120. """
  121. for key, value in kwargs.items():
  122. if key not in set_algo_parameters_config_func_map:
  123. raise ValueError("Set context keyword %s is not recognized!" % key)
  124. set_func = set_algo_parameters_config_func_map[key]
  125. set_func(value)
  126. def get_algo_parameters(attr_key):
  127. """
  128. Get algo parameter config attributes.
  129. Note:
  130. Return value according to the attribute value.
  131. Args:
  132. attr_key (str): The key of the attribute.
  133. Raises:
  134. ValueError: If context keyword is not recognized.
  135. """
  136. if attr_key not in get_algo_parameters_config_func_map:
  137. raise ValueError("Get context keyword %s is not recognized!" % attr_key)
  138. get_func = get_algo_parameters_config_func_map[attr_key]
  139. return get_func()
  140. def reset_algo_parameters():
  141. """Reset algo parameter attributes."""
  142. _algo_parameter_config().reset_algo_parameters()