| @@ -0,0 +1,3 @@ | |||||
| from .core._config import * | |||||
| __import__("mprop").init() | |||||
| @@ -0,0 +1,205 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import os | |||||
| from contextlib import contextmanager | |||||
# Module-level state backing the configuration properties of this module.
# The double-underscore names are ordinary module globals: name mangling only
# applies inside class bodies, so they are reachable under these exact names.
__compute_mode = "default"
__conv_format = "default"
_benchmark_kernel = False
_deterministic_kernel = False
# os.getenv returns a str whenever the variable is set, so coerce to int to
# keep the same type as the built-in default of 2. A non-numeric value now
# fails loudly with ValueError at import instead of leaking a string.
_async_level = int(os.getenv("MEGENGINE_INTERP_ASYNC_LEVEL", "2"))
# Names exported by a star-import of this module. The underscore-prefixed
# entries are listed explicitly because a star-import would otherwise skip them.
__all__ = [
    # public options
    "benchmark_kernel",
    "deterministic_kernel",
    "async_level",
    # private options and helpers
    "_compute_mode",
    "_conv_format",
    "_override",
]
@property
def benchmark_kernel(mod):
    r"""Get or set whether possible algorithms are run on the real device to find the best one.

    The default option is ``False``, which means a heuristic is used to choose
    the fastest algorithm.

    Examples:

        .. code-block::

           import megengine as mge
           mge.config.benchmark_kernel = True
    """
    return _benchmark_kernel


@benchmark_kernel.setter
def benchmark_kernel(mod, option: bool):
    # ``mod`` is the object the property is accessed on and is not used here;
    # the value is stored in the module-level ``_benchmark_kernel`` global.
    global _benchmark_kernel
    _benchmark_kernel = option
@property
def deterministic_kernel(mod):
    r"""Get or set whether the fastest algorithm chosen must be reproducible.

    The default option is ``False``, which means the chosen algorithm is not
    required to be reproducible.

    Examples:

        .. code-block::

           import megengine as mge
           mge.config.deterministic_kernel = True
    """
    return _deterministic_kernel


@deterministic_kernel.setter
def deterministic_kernel(mod, option: bool):
    # ``mod`` is the object the property is accessed on and is not used here;
    # the value is stored in the module-level ``_deterministic_kernel`` global.
    global _deterministic_kernel
    _deterministic_kernel = option
@property
def async_level(mod) -> int:
    r"""Get or set the level that controls whether errors are raised exactly when an op is invoked.

    The default level is 2, which means both device side and user side errors
    are asynchronous.

    Examples:

        .. code-block::

           import megengine as mge
           mge.config.async_level = 2
    """
    return _async_level


@async_level.setter
def async_level(mod, level: int):
    # ``mod`` is the object the property is accessed on and is not used here;
    # the value is stored in the module-level ``_async_level`` global.
    global _async_level
    _async_level = level
@property
def _compute_mode(mod):
    r"""Get or set the precision used for intermediate results.

    The default option is "default", which places no special requirement on
    the precision. When set to 'float32', float32 is used for the accumulator
    and intermediate results, but this is only effective when input and output
    are of float16 dtype.

    Examples:

        .. code-block::

           import megengine as mge
           mge.config._compute_mode = "default"
    """
    return __compute_mode


@_compute_mode.setter
def _compute_mode(mod, _compute_mode: str):
    # The parameter shadows the property name inside this function only; the
    # value is stored in the module-level ``__compute_mode`` global.
    global __compute_mode
    __compute_mode = _compute_mode
@property
def _conv_format(mod):
    r"""Get or set the convolution data/filter/output layout format.

    The default option is "default", which means no special format is
    imposed. The layout definitions are:

    ``NCHW`` layout: ``{N, C, H, W}``
    ``NHWC`` layout: ``{N, H, W, C}``
    ``NHWCD4`` layout: ``{N, H, (C + 3) / 4, W, 4}``
    ``NHWCD4I`` layout: with ``align_axis = 2``
    ``NCHW4`` layout: ``{N, C/4, H, W, 4}``
    ``NCHW88`` layout: ``{N, C/8, H, W, 8}``
    ``CHWN4`` layout: ``{C/4, H, W, N, 4}``
    ``NCHW64`` layout: ``{N, C/64, H, W, 64}``

    Examples:

        .. code-block::

           import megengine as mge
           mge.config._conv_format = "NHWC"
    """
    return __conv_format


@_conv_format.setter
def _conv_format(mod, format: str):
    # No validation is performed here; the string is stored as given in the
    # module-level ``__conv_format`` global.
    global __conv_format
    __conv_format = format
def _reset_execution_config(
    benchmark_kernel=None,
    deterministic_kernel=None,
    async_level=None,
    compute_mode=None,
    conv_format=None,
):
    """Overwrite the selected execution flags and return their previous values.

    Every argument left as ``None`` keeps the corresponding flag unchanged.

    Returns:
        A 5-tuple ``(benchmark_kernel, deterministic_kernel, async_level,
        compute_mode, conv_format)`` with the values in effect before the call.
    """
    module_state = globals()
    # Pair each backing global with the value requested for it, in the same
    # order as the returned tuple.
    requested = (
        ("_benchmark_kernel", benchmark_kernel),
        ("_deterministic_kernel", deterministic_kernel),
        ("_async_level", async_level),
        ("__compute_mode", compute_mode),
        ("__conv_format", conv_format),
    )
    # Snapshot everything first so the result reflects the state before any write.
    previous = tuple(module_state[name] for name, _ in requested)
    for name, value in requested:
        if value is not None:
            module_state[name] = value
    return previous
@contextmanager
def _override(
    benchmark_kernel=None,
    deterministic_kernel=None,
    async_level=None,
    compute_mode=None,
    conv_format=None,
):
    r"""Temporarily override the global execution config.

    The arguments that are not ``None`` are applied on entry, and the values
    that were in effect before are written back on exit, even if the wrapped
    code raises. It can be attached to a function as a decorator.

    Examples:

        .. code-block::

           import megengine as mge

           @mge.config._override(
               benchmark_kernel=True,
               deterministic_kernel=False,
               async_level=2,
               compute_mode="float32",
               conv_format="NHWC",
           )
           def train():
               ...
    """
    saved_flags = _reset_execution_config(
        benchmark_kernel=benchmark_kernel,
        deterministic_kernel=deterministic_kernel,
        async_level=async_level,
        compute_mode=compute_mode,
        conv_format=conv_format,
    )
    try:
        yield
    finally:
        # Put back whatever was active before entering the context.
        _reset_execution_config(*saved_flags)
def _get_actual_op_param(function_param, config_param):
    """Pick the parameter an op should use, letting a non-default config value win.

    Returns ``function_param`` when ``config_param`` equals ``"default"``;
    otherwise returns ``config_param``.
    """
    if config_param == "default":
        return function_param
    return config_param
| @@ -8,26 +8,38 @@ | |||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import os | import os | ||||
| from ..core import _config | |||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..logger import get_logger | from ..logger import get_logger | ||||
| from ..utils.deprecation import deprecated | from ..utils.deprecation import deprecated | ||||
| Strategy = builtin.ops.Convolution.Strategy | Strategy = builtin.ops.Convolution.Strategy | ||||
| _execution_strategy = os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC") | |||||
# Only the presence of the old variable is checked here, so that users who
# still set it are pointed at the replacement name.
if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") is not None:
    get_logger().warning(
        "Environment variable `MEGENGINE_CONV_EXECUTION_STRATEGY` is deprecated, please use `MEGENGINE_EXECUTION_STRATEGY`"
    )
# Maps each token accepted in a strategy string to the ``Strategy`` member of
# the same name.
_valid_string_option = {
    name: getattr(Strategy, name) for name in ("REPRODUCIBLE", "HEURISTIC", "PROFILE")
}
def get_execution_strategy() -> Strategy:
    r"""Returns the execution strategy of :class:`~module..Conv2d` and :func:`~.matmul`

    The result is assembled from the ``_benchmark_kernel`` and
    ``_deterministic_kernel`` flags held by ``_config``.

    See :func:`~.set_execution_strategy` for possible return values
    """
    # ``Strategy(0)`` is the starting value that the selected members are OR-ed into.
    strategy = Strategy(0)
    # PROFILE and HEURISTIC are alternatives selected by the benchmark flag.
    strategy |= Strategy.PROFILE if _config._benchmark_kernel else Strategy.HEURISTIC
    if _config._deterministic_kernel:
        strategy |= Strategy.REPRODUCIBLE
    return strategy
| def set_execution_strategy(option): | def set_execution_strategy(option): | ||||
| @@ -50,7 +62,6 @@ def set_execution_strategy(option): | |||||
| * 'HEURISTIC' uses heuristic to choose the fastest algorithm. | * 'HEURISTIC' uses heuristic to choose the fastest algorithm. | ||||
| * 'PROFILE' runs possible algorithms on real device to find the best one. | * 'PROFILE' runs possible algorithms on real device to find the best one. | ||||
| * 'PROFILE_HEURISTIC' uses profiling result and heuristic to choose the fastest algorithm. | |||||
| * 'PROFILE_REPRODUCIBLE' uses the fastest of profiling result that is also reproducible. | * 'PROFILE_REPRODUCIBLE' uses the fastest of profiling result that is also reproducible. | ||||
| * 'HEURISTIC_REPRODUCIBLE' uses heuristic to choose the fastest algorithm that is also reproducible. | * 'HEURISTIC_REPRODUCIBLE' uses heuristic to choose the fastest algorithm that is also reproducible. | ||||
| @@ -58,29 +69,33 @@ def set_execution_strategy(option): | |||||
| It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'. | It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'. | ||||
| """ | """ | ||||
| valid_string_option = { | |||||
| "REPRODUCIBLE": Strategy.REPRODUCIBLE, | |||||
| "HEURISTIC": Strategy.HEURISTIC, | |||||
| "PROFILE": Strategy.PROFILE, | |||||
| } | |||||
| global _execution_strategy # pylint: disable=global-statement | |||||
| if isinstance(option, Strategy): | if isinstance(option, Strategy): | ||||
| _execution_strategy = option | |||||
| _config._benchmark_kernel = ( | |||||
| True if option & _valid_string_option["PROFILE"] != Strategy(0) else False | |||||
| ) | |||||
| _config._deterministic_kernel = ( | |||||
| True | |||||
| if option & _valid_string_option["REPRODUCIBLE"] != Strategy(0) | |||||
| else False | |||||
| ) | |||||
| return | return | ||||
| assert isinstance(option, str) | assert isinstance(option, str) | ||||
| strategy_tmp = Strategy(0) | |||||
| _config._benchmark_kernel = False | |||||
| _config._deterministic_kernel = False | |||||
| for opt in option.split("_"): | for opt in option.split("_"): | ||||
| if not opt in valid_string_option: | |||||
| if not opt in _valid_string_option: | |||||
| raise ValueError( | raise ValueError( | ||||
| "Valid option can only be one of {}, or combine them with '_'.".format( | "Valid option can only be one of {}, or combine them with '_'.".format( | ||||
| valid_string_option.keys() | |||||
| _valid_string_option.keys() | |||||
| ) | ) | ||||
| ) | ) | ||||
| strategy_tmp = strategy_tmp | valid_string_option[opt] | |||||
| _execution_strategy = strategy_tmp | |||||
| _config._benchmark_kernel |= _valid_string_option[opt] == Strategy.PROFILE | |||||
| _config._deterministic_kernel |= ( | |||||
| _valid_string_option[opt] == Strategy.REPRODUCIBLE | |||||
| ) | |||||
| @deprecated(version="1.3", reason="use get_execution_strategy() instead") | @deprecated(version="1.3", reason="use get_execution_strategy() instead") | ||||
| @@ -91,3 +106,6 @@ def get_conv_execution_strategy() -> str: | |||||
@deprecated(version="1.3", reason="use set_execution_strategy() instead")
def set_conv_execution_strategy(option: str):
    """Deprecated alias that forwards ``option`` to :func:`set_execution_strategy`."""
    result = set_execution_strategy(option)
    return result
# Apply the strategy named by the environment when this module is imported,
# falling back to "HEURISTIC" when the variable is unset.
set_execution_strategy(os.environ.get("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC"))