# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
- """Configuration of parameters for strategy-searching algorithm in auto_parallel"""
-
- import threading
- from mindspore._c_expression import CostModelContext
- from mindspore._checkparam import args_type_check
-
- __all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
-
-
- class _AlgoParameterConfig():
- """
- _AlgoParameterConfig is the configuration of setting parameters used in th algorithm.
-
- Note:
- Creating a config through instantiating _AlgoParameterConfig object is not recommended.
- Use algo_parameter_config() to get the configuration since _AlgoParameterConfig is singleton.
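
    A minimal access sketch, using the module-level accessor defined below:

    Examples:
        >>> config = _algo_parameter_config()
        >>> config.check_config_handle()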
- """
- _instance = None
- _instance_lock = threading.Lock()
-
- def __init__(self):
- self._config_handle = CostModelContext.get_instance()
-
- def check_config_handle(self):
- """
- Check config handle.
-
- Raises:
- ValueError: If the config handle is none.
- """
- if self._config_handle is None:
- raise ValueError("Config handle is none!!!")
-
- def set_fully_use_devices(self, not_fully):
- self.check_config_handle()
- self._config_handle.set_fully_use_devices(not_fully)
-
- def get_fully_use_devices(self):
- self.check_config_handle()
- return self._config_handle.get_fully_use_devices()
-
- def set_elementwise_op_strategy_follow(self, element_strategy_follow):
- self.check_config_handle()
- self._config_handle.set_elementwise_op_strategy_follow(element_strategy_follow)
-
- def get_elementwise_op_strategy_follow(self):
- self.check_config_handle()
- return self._config_handle.get_elementwise_op_strategy_follow()
-
- def set_tensor_slice_align_enable(self, align_enable):
- self.check_config_handle()
- self._config_handle.set_tensor_slice_align_enable(align_enable)
-
- def get_tensor_slice_align_enable(self):
- self.check_config_handle()
- return self._config_handle.get_tensor_slice_align_enable()
-
- def set_tensor_slice_align_size(self, align_size):
- """
- Set tensor slice align size.
-
- Args:
- align_size (int): The minimum tensor slice shape.
-
- Raises:
- ValueError: If align_size is not in [1, 1024].
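
        A minimal sketch; 32 is an illustrative in-range value (values outside
        [1, 1024] raise ValueError):

        Examples:
            >>> _algo_parameter_config().set_tensor_slice_align_size(32)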
- """
- self.check_config_handle()
- if align_size < 1 or align_size > 1024:
- raise ValueError('Align_size must be in [1, 1024], but got {}'.format(align_size))
- self._config_handle.set_tensor_slice_align_size(align_size)
-
- def get_tensor_slice_align_size(self):
- self.check_config_handle()
- return self._config_handle.get_tensor_slice_align_size()
-
- def reset_algo_parameters(self):
- self.check_config_handle()
- self._config_handle.reset_algo_parameters()
-
-
- _g_algo_parameter_config = None
-
-
- def _algo_parameter_config():
- """
- Get the global _g_algo_parameter_config. If it is not created, create a new one.
-
- Returns:
- The global _g_algo_parameter_config.
- """
- global _g_algo_parameter_config
- if _g_algo_parameter_config is None:
- _g_algo_parameter_config = _AlgoParameterConfig()
- return _g_algo_parameter_config
-
-
- set_algo_parameters_config_func_map = {
- "fully_use_devices": _algo_parameter_config().set_fully_use_devices,
- "elementwise_op_strategy_follow": _algo_parameter_config().set_elementwise_op_strategy_follow,
- "tensor_slice_align_enable": _algo_parameter_config().set_tensor_slice_align_enable,
- "tensor_slice_align_size": _algo_parameter_config().set_tensor_slice_align_size}
-
-
- get_algo_parameters_config_func_map = {
- "fully_use_devices": _algo_parameter_config().get_fully_use_devices,
- "elementwise_op_strategy_follow": _algo_parameter_config().get_elementwise_op_strategy_follow,
- "tensor_slice_align_enable": _algo_parameter_config().get_tensor_slice_align_enable,
- "tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size}
-
-
- @args_type_check(tensor_slice_align_enable=bool, tensor_slice_align_size=int,
- fully_use_devices=bool, elementwise_op_strategy_follow=bool)
- def set_algo_parameters(**kwargs):
- """
- Set algo parameter config.
-
- Note:
- Attribute name is needed.
-
- Args:
- tensor_slice_align_enable (bool): Whether checking tensor slice shape for MatMul. Default: False
- tensor_slice_align_size (int): The minimum tensor slice shape of MatMul, the value must be in [1, 1024].
- Default: 16
- fully_use_devices (bool): Whether ONLY generating strategies that fully use all available devices. Default: True
- elementwise_op_strategy_follow (bool): Whether the elementwise operator have the same strategies as its
- subsequent operators. Default: False
-
- Raises:
- ValueError: If context keyword is not recognized.
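
    A minimal usage sketch; the values below are illustrative, not tuned
    recommendations:

    Examples:
        >>> set_algo_parameters(fully_use_devices=False)
        >>> set_algo_parameters(tensor_slice_align_enable=True, tensor_slice_align_size=32)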
- """
- for key, value in kwargs.items():
- if key not in set_algo_parameters_config_func_map:
- raise ValueError("Set context keyword %s is not recognized!" % key)
- set_func = set_algo_parameters_config_func_map[key]
- set_func(value)
-
-
- def get_algo_parameters(attr_key):
- """
- Get algo parameter config attributes.
-
- Note:
- Return value according to the attribute value.
-
- Args:
- attr_key (str): The key of the attribute.
-
- Raises:
- ValueError: If context keyword is not recognized.
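
    A minimal sketch; "fully_use_devices" is one of the recognized keys, and the
    returned value reflects whatever was last set:

    Examples:
        >>> value = get_algo_parameters("fully_use_devices")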
- """
- if attr_key not in get_algo_parameters_config_func_map:
- raise ValueError("Get context keyword %s is not recognized!" % attr_key)
- get_func = get_algo_parameters_config_func_map[attr_key]
- return get_func()
-
-
- def reset_algo_parameters():
- """Reset algo parameter attributes."""
    _algo_parameter_config().reset_algo_parameters()