You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

algo_parameter_config.py 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Configuration of parameters for strategy-searching algorithm in auto_parallel"""
  16. import threading
  17. from mindspore._c_expression import CostModelContext
  18. from mindspore._checkparam import args_type_check
  19. __all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
class _AlgoParameterConfig():
    """
    _AlgoParameterConfig is the configuration of setting parameters used in the algorithm.

    Thin Python wrapper: every setter/getter delegates to the C++ CostModelContext
    singleton, which is where the values actually live.

    Note:
        Creating a config through instantiating _AlgoParameterConfig object is not recommended.
        Use algo_parameter_config() to get the configuration since _AlgoParameterConfig is singleton.
    """
    # NOTE(review): these two class attributes look like lock-based singleton
    # machinery, but nothing in this file uses them — the module-level
    # _algo_parameter_config() implements the singleton instead. Presumably
    # historical; confirm before removing.
    _instance = None
    _instance_lock = threading.Lock()

    def __init__(self):
        # Handle to the process-wide C++ cost-model context.
        self._config_handle = CostModelContext.get_instance()

    def check_config_handle(self):
        """
        Check config handle.

        Raises:
            ValueError: If the config handle is none.
        """
        if self._config_handle is None:
            raise ValueError("Config handle is none!!!")

    def set_fully_use_devices(self, not_fully):
        """Set whether ONLY strategies that fully use all available devices are generated."""
        self.check_config_handle()
        self._config_handle.set_fully_use_devices(not_fully)

    def get_fully_use_devices(self):
        """Return the current fully-use-devices flag from the cost-model context."""
        self.check_config_handle()
        return self._config_handle.get_fully_use_devices()

    def set_elementwise_op_strategy_follow(self, element_strategy_follow):
        """Set whether elementwise operators follow the strategies of their subsequent operators."""
        self.check_config_handle()
        self._config_handle.set_elementwise_op_strategy_follow(element_strategy_follow)

    def get_elementwise_op_strategy_follow(self):
        """Return the current elementwise-op-strategy-follow flag."""
        self.check_config_handle()
        return self._config_handle.get_elementwise_op_strategy_follow()

    def set_tensor_slice_align_enable(self, align_enable):
        """Set whether the tensor-slice shape of MatMul is checked for alignment."""
        self.check_config_handle()
        self._config_handle.set_tensor_slice_align_enable(align_enable)

    def get_tensor_slice_align_enable(self):
        """Return the current tensor-slice-align-enable flag."""
        self.check_config_handle()
        return self._config_handle.get_tensor_slice_align_enable()

    def set_tensor_slice_align_size(self, align_size):
        """
        Set tensor slice align size.

        Args:
            align_size (int): The minimum tensor slice shape.

        Raises:
            ValueError: If align_size is not in [1, 1024].
        """
        self.check_config_handle()
        # Validate here in Python so the caller gets a clear error before the
        # value crosses into the C++ context.
        if align_size < 1 or align_size > 1024:
            raise ValueError('Align_size must be in [1, 1024], but got {}'.format(align_size))
        self._config_handle.set_tensor_slice_align_size(align_size)

    def get_tensor_slice_align_size(self):
        """Return the current minimum tensor-slice shape of MatMul."""
        self.check_config_handle()
        return self._config_handle.get_tensor_slice_align_size()

    def set_dp_algo_enable_approxi(self, enable_flag):
        """Set whether approximation is enabled in the DP algorithms."""
        self.check_config_handle()
        self._config_handle.set_dp_algo_enable_approxi(enable_flag)

    def get_dp_algo_enable_approxi(self):
        """Return whether approximation is enabled in the DP algorithms."""
        self.check_config_handle()
        return self._config_handle.get_dp_algo_enable_approxi()

    def set_dp_algo_approxi_epsilon(self, epsilon):
        """Set the epsilon value used by the approximate DP algorithm."""
        self.check_config_handle()
        self._config_handle.set_dp_algo_approxi_epsilon(epsilon)

    def get_dp_algo_approxi_epsilon(self):
        """Return the epsilon value used by the approximate DP algorithm."""
        self.check_config_handle()
        return self._config_handle.get_dp_algo_approxi_epsilon()

    def reset_algo_parameters(self):
        """Reset all strategy-searching parameters in the cost-model context."""
        self.check_config_handle()
        self._config_handle.reset_algo_parameters()
  87. _g_algo_parameter_config = None
  88. def _algo_parameter_config():
  89. """
  90. Get the global _g_algo_parameter_config. If it is not created, create a new one.
  91. Returns:
  92. The global _g_algo_parameter_config.
  93. """
  94. global _g_algo_parameter_config
  95. if _g_algo_parameter_config is None:
  96. _g_algo_parameter_config = _AlgoParameterConfig()
  97. return _g_algo_parameter_config
# Dispatch table mapping each public parameter name to its setter on the
# singleton configuration. Built once at import time; set_algo_parameters()
# validates keyword names against this map.
set_algo_parameters_config_func_map = {
    "fully_use_devices": _algo_parameter_config().set_fully_use_devices,
    "elementwise_op_strategy_follow": _algo_parameter_config().set_elementwise_op_strategy_follow,
    "tensor_slice_align_enable": _algo_parameter_config().set_tensor_slice_align_enable,
    "tensor_slice_align_size": _algo_parameter_config().set_tensor_slice_align_size,
    "enable_algo_approxi": _algo_parameter_config().set_dp_algo_enable_approxi,
    "algo_approxi_epsilon": _algo_parameter_config().set_dp_algo_approxi_epsilon}
# Dispatch table mapping each public parameter name to its getter on the
# singleton configuration; mirrors set_algo_parameters_config_func_map and is
# used by get_algo_parameters() to validate the requested key.
get_algo_parameters_config_func_map = {
    "fully_use_devices": _algo_parameter_config().get_fully_use_devices,
    "elementwise_op_strategy_follow": _algo_parameter_config().get_elementwise_op_strategy_follow,
    "tensor_slice_align_enable": _algo_parameter_config().get_tensor_slice_align_enable,
    "tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size,
    "enable_algo_approxi": _algo_parameter_config().get_dp_algo_enable_approxi,
    "algo_approxi_epsilon": _algo_parameter_config().get_dp_algo_approxi_epsilon}
  112. @args_type_check(tensor_slice_align_enable=bool, tensor_slice_align_size=int,
  113. fully_use_devices=bool, elementwise_op_strategy_follow=bool,
  114. enable_algo_approxi=bool, algo_approxi_epsilon=float)
  115. def set_algo_parameters(**kwargs):
  116. """
  117. Set algo parameter config.
  118. Note:
  119. The attribute name is required.
  120. Args:
  121. tensor_slice_align_enable (bool): Whether to check the shape of tensor slice of MatMul. Default: False
  122. tensor_slice_align_size (int): The minimum tensor slice shape of MatMul, the value must be in [1, 1024].
  123. Default: 16
  124. fully_use_devices (bool): Whether ONLY generating strategies that fully use all available devices. Default: True
  125. elementwise_op_strategy_follow (bool): Whether the elementwise operator has the same strategies as its
  126. subsequent operators. Default: False
  127. enable_algo_approxi (bool): Whether to enable the approximation in the DP algorithms.
  128. algo_approxi_epsilon (float): The epsilon value used int the approximation DP algorithm.
  129. Raises:
  130. ValueError: If context keyword is not recognized.
  131. """
  132. for key, value in kwargs.items():
  133. if key not in set_algo_parameters_config_func_map:
  134. raise ValueError("Set context keyword %s is not recognized!" % key)
  135. set_func = set_algo_parameters_config_func_map[key]
  136. set_func(value)
  137. def get_algo_parameters(attr_key):
  138. """
  139. Get algo parameter config attributes.
  140. Note:
  141. Returns the specified attribute value.
  142. Args:
  143. attr_key (str): The key of the attribute.
  144. Raises:
  145. ValueError: If context keyword is not recognized.
  146. """
  147. if attr_key not in get_algo_parameters_config_func_map:
  148. raise ValueError("Get context keyword %s is not recognized!" % attr_key)
  149. get_func = get_algo_parameters_config_func_map[attr_key]
  150. return get_func()
  151. def reset_algo_parameters():
  152. """Reset algo parameter attributes."""
  153. _algo_parameter_config().reset_algo_parameters()