You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

algo_parameter_config.py 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Configuration of parameters for strategy-searching algorithm in auto_parallel"""
  16. import threading
  17. from mindspore._c_expression import CostModelContext
  18. from mindspore._checkparam import args_type_check
  19. __all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
  20. class _AlgoParameterConfig():
  21. """
  22. _AlgoParameterConfig is the configuration of setting parameters used in th algorithm.
  23. Note:
  24. Creating a config through instantiating _AlgoParameterConfig object is not recommended.
  25. Use algo_parameter_config() to get the configuration since _AlgoParameterConfig is singleton.
  26. """
  27. _instance = None
  28. _instance_lock = threading.Lock()
  29. def __init__(self):
  30. self._config_handle = CostModelContext.get_instance()
  31. def check_config_handle(self):
  32. """
  33. Check config handle.
  34. Raises:
  35. ValueError: If the config handle is none.
  36. """
  37. if self._config_handle is None:
  38. raise ValueError("Config handle is none!!!")
  39. def set_fully_use_devices(self, not_fully):
  40. self.check_config_handle()
  41. self._config_handle.set_fully_use_devices(not_fully)
  42. def get_fully_use_devices(self):
  43. self.check_config_handle()
  44. return self._config_handle.get_fully_use_devices()
  45. def set_elementwise_op_strategy_follow(self, element_strategy_follow):
  46. self.check_config_handle()
  47. self._config_handle.set_elementwise_op_strategy_follow(element_strategy_follow)
  48. def get_elementwise_op_strategy_follow(self):
  49. self.check_config_handle()
  50. return self._config_handle.get_elementwise_op_strategy_follow()
  51. def set_tensor_slice_align_enable(self, align_enable):
  52. self.check_config_handle()
  53. self._config_handle.set_tensor_slice_align_enable(align_enable)
  54. def get_tensor_slice_align_enable(self):
  55. self.check_config_handle()
  56. return self._config_handle.get_tensor_slice_align_enable()
  57. def set_tensor_slice_align_size(self, align_size):
  58. """
  59. Set tensor slice align size.
  60. Args:
  61. align_size (int): The minimum tensor slice shape.
  62. Raises:
  63. ValueError: If align_size is not in [1, 1024].
  64. """
  65. self.check_config_handle()
  66. if align_size < 1 or align_size > 1024:
  67. raise ValueError('Align_size must be in [1, 1024], but got {}'.format(align_size))
  68. self._config_handle.set_tensor_slice_align_size(align_size)
  69. def get_tensor_slice_align_size(self):
  70. self.check_config_handle()
  71. return self._config_handle.get_tensor_slice_align_size()
  72. def reset_algo_parameters(self):
  73. self.check_config_handle()
  74. self._config_handle.reset_algo_parameters()
  # Module-wide configuration object; stays None until the first _algo_parameter_config() call.
  _g_algo_parameter_config = None


  def _algo_parameter_config():
      """
      Return the module-wide _AlgoParameterConfig, creating it on the first call.

      Returns:
          The global _g_algo_parameter_config.
      """
      global _g_algo_parameter_config
      if _g_algo_parameter_config is not None:
          return _g_algo_parameter_config
      _g_algo_parameter_config = _AlgoParameterConfig()
      return _g_algo_parameter_config
  # Bind the shared configuration once; both dispatch tables hold bound methods of this object.
  _config = _algo_parameter_config()

  # Keyword name -> setter used by set_algo_parameters().
  set_algo_parameters_config_func_map = {
      "fully_use_devices": _config.set_fully_use_devices,
      "elementwise_op_strategy_follow": _config.set_elementwise_op_strategy_follow,
      "tensor_slice_align_enable": _config.set_tensor_slice_align_enable,
      "tensor_slice_align_size": _config.set_tensor_slice_align_size,
  }

  # Attribute key -> getter used by get_algo_parameters().
  get_algo_parameters_config_func_map = {
      "fully_use_devices": _config.get_fully_use_devices,
      "elementwise_op_strategy_follow": _config.get_elementwise_op_strategy_follow,
      "tensor_slice_align_enable": _config.get_tensor_slice_align_enable,
      "tensor_slice_align_size": _config.get_tensor_slice_align_size,
  }
  @args_type_check(tensor_slice_align_enable=bool, tensor_slice_align_size=int,
                   fully_use_devices=bool, elementwise_op_strategy_follow=bool)
  def set_algo_parameters(**kwargs):
      """
      Set algo parameter config.

      Note:
          Attribute name is needed.

          Every keyword is validated before any setter runs, so an unrecognized keyword
          leaves the configuration untouched. A value rejected by a setter can still
          interrupt the update after earlier keywords were applied.

      Args:
          tensor_slice_align_enable (bool): Whether checking tensor slice shape for MatMul. Default: False
          tensor_slice_align_size (int): The minimum tensor slice shape of MatMul, the value must be in [1, 1024].
              Default: 16
          fully_use_devices (bool): Whether ONLY generating strategies that fully use all available devices.
              Default: True
          elementwise_op_strategy_follow (bool): Whether the elementwise operator have the same strategies as its
              subsequent operators. Default: False

      Raises:
          ValueError: If context keyword is not recognized.
          ValueError: If tensor_slice_align_size is not in [1, 1024].
      """
      # First pass: reject unknown keywords up front so a typo cannot leave the
      # configuration half-updated.
      for key in kwargs:
          if key not in set_algo_parameters_config_func_map:
              raise ValueError("Set context keyword %s is not recognized!" % key)

      # Second pass: apply the values in the order the caller supplied them.
      for key, value in kwargs.items():
          set_func = set_algo_parameters_config_func_map[key]
          set_func(value)
  def get_algo_parameters(attr_key):
      """
      Get algo parameter config attributes.

      Note:
          Return value according to the attribute value.

      Args:
          attr_key (str): The key of the attribute.

      Returns:
          The value returned by the getter registered for `attr_key`.

      Raises:
          ValueError: If context keyword is not recognized.
      """
      if attr_key not in get_algo_parameters_config_func_map:
          # Wrap the key in a one-element tuple: a tuple key passed bare to `%` would be
          # unpacked as the format arguments and raise TypeError instead of ValueError.
          raise ValueError("Get context keyword %s is not recognized!" % (attr_key,))
      get_func = get_algo_parameters_config_func_map[attr_key]
      return get_func()
  def reset_algo_parameters():
      """Reset algo parameter attributes."""
      config = _algo_parameter_config()
      config.reset_algo_parameters()