
_config.py

# -*- coding: utf-8 -*-
import os
from contextlib import contextmanager

from ._imperative_rt.core2 import (
    _clear_algorithm_cache,
    get_auto_format_convert,
    get_option,
    set_auto_format_convert,
    set_option,
)

__compute_mode = "default"
__conv_format = "default"
_benchmark_kernel = False
_deterministic_kernel = False

__all__ = [
    "benchmark_kernel",
    "deterministic_kernel",
    "async_level",
    "disable_memory_forwarding",
    "_compute_mode",
    "_conv_format",
    "_override",
]


@property
def benchmark_kernel(mod):
    r"""Whether or not to run the possible algorithms on the real device to find the best one.
    The default option is false, which means a heuristic is used to choose the fastest algorithm.

    Examples:

        .. code-block::

            import megengine as mge
            mge.config.benchmark_kernel = True
    """
    return _benchmark_kernel


@benchmark_kernel.setter
def benchmark_kernel(mod, option: bool):
    global _benchmark_kernel
    # switching strategy invalidates previously cached algorithms, so clear the cache
    if option != _benchmark_kernel:
        _clear_algorithm_cache()
    _benchmark_kernel = option


@property
def deterministic_kernel(mod):
    r"""Whether or not the fastest algorithm chosen is reproducible.
    The default option is false, which means the algorithm is not guaranteed to be reproducible.

    Examples:

        .. code-block::

            import megengine as mge
            mge.config.deterministic_kernel = True
    """
    return _deterministic_kernel


@deterministic_kernel.setter
def deterministic_kernel(mod, option: bool):
    global _deterministic_kernel
    _deterministic_kernel = option


@property
def async_level(mod) -> int:
    r"""Get or set whether errors are raised exactly when an op is invoked.
    The default level is 2, which means both device-side and user-side errors are asynchronous.

    Examples:

        .. code-block::

            import megengine as mge
            mge.config.async_level = 2
    """
    return get_option("async_level")


@async_level.setter
def async_level(mod, level: int):
    assert level >= 0 and level <= 2, "async_level should be 0, 1 or 2"
    set_option("async_level", level)


@property
def disable_memory_forwarding(mod) -> bool:
    r"""Get or set whether memory forwarding is disabled.
    The default option is false, which means storage may be shared among tensors.

    Examples:

        .. code-block::

            import megengine as mge
            mge.config.disable_memory_forwarding = False
    """
    return bool(get_option("disable_memory_forwarding"))


@disable_memory_forwarding.setter
def disable_memory_forwarding(mod, disable: bool):
    set_option("disable_memory_forwarding", disable)


@property
def _compute_mode(mod):
    r"""Get or set the precision of intermediate results. The default option is "default",
    which means no special requirement is placed on intermediate precision. When set to
    "float32", float32 is used for the accumulator and intermediate results, but this only
    takes effect when the input and output are of float16 dtype.

    Examples:

        .. code-block::

            import megengine as mge
            mge.config._compute_mode = "default"
    """
    return __compute_mode


@_compute_mode.setter
def _compute_mode(mod, _compute_mode: str):
    global __compute_mode
    __compute_mode = _compute_mode


@property
def _conv_format(mod):
    r"""Get or set the convolution data/filter/output layout format. The default option is
    "default", which means no particular format is enforced. The available layout
    definitions are:

    * ``NCHW`` layout: ``{N, C, H, W}``
    * ``NHWC`` layout: ``{N, H, W, C}``
    * ``NHWCD4`` layout: ``{N, H, (C + 3) / 4, W, 4}``
    * ``NHWCD4I`` layout: like ``NHWCD4``, but with ``align_axis = 2``
    * ``NCHW4`` layout: ``{N, C/4, H, W, 4}``
    * ``NCHW88`` layout: ``{N, C/8, H, W, 8}``
    * ``CHWN4`` layout: ``{C/4, H, W, N, 4}``
    * ``NCHW64`` layout: ``{N, C/64, H, W, 64}``

    Examples:

        .. code-block::

            import megengine as mge
            mge.config._conv_format = "NHWC"
    """
    return __conv_format


@_conv_format.setter
def _conv_format(mod, format: str):
    global __conv_format
    __conv_format = format
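

# Worked example based on the layout definitions in the ``_conv_format`` docstring above.
# The tensor shape is illustrative only: an NCHW tensor of shape (N=2, C=8, H=32, W=32)
# would be laid out as:
#
#     NHWC    -> (2, 32, 32, 8)
#     NHWCD4  -> (2, 32, (8 + 3) // 4, 32, 4) == (2, 32, 2, 32, 4)
#     NCHW4   -> (2, 8 // 4, 32, 32, 4)       == (2, 2, 32, 32, 4)
#     NCHW88  -> (2, 8 // 8, 32, 32, 8)       == (2, 1, 32, 32, 8)
#     CHWN4   -> (8 // 4, 32, 32, 2, 4)       == (2, 32, 32, 2, 4)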


@property
def _auto_format_convert(mod):
    r"""Automatically convert the order of indexing params for an NCHW Tensor to NHWC order.
    The default value is False, which means no conversion is performed.

    Examples:

        .. code-block::

            import megengine as mge
            mge.config._auto_format_convert = True
    """
    return get_auto_format_convert()


@_auto_format_convert.setter
def _auto_format_convert(mod, option: bool):
    set_auto_format_convert(option)


def _reset_execution_config(
    benchmark_kernel=None,
    deterministic_kernel=None,
    async_level=None,
    compute_mode=None,
    conv_format=None,
    auto_format_convert=None,
):
    global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format
    orig_flags = (
        _benchmark_kernel,
        _deterministic_kernel,
        get_option("async_level"),
        __compute_mode,
        __conv_format,
        get_auto_format_convert(),
    )
    if benchmark_kernel is not None:
        _benchmark_kernel = benchmark_kernel
    if deterministic_kernel is not None:
        _deterministic_kernel = deterministic_kernel
    if async_level is not None:
        set_option("async_level", async_level)
    if compute_mode is not None:
        __compute_mode = compute_mode
    if conv_format is not None:
        __conv_format = conv_format
    if auto_format_convert is not None:
        set_auto_format_convert(auto_format_convert)
    return orig_flags


@contextmanager
def _override(
    benchmark_kernel=None,
    deterministic_kernel=None,
    async_level=None,
    compute_mode=None,
    conv_format=None,
    auto_format_convert=None,
):
    r"""A context manager that temporarily overrides the global config; it can also be
    attached to a function as a decorator.

    Examples:

        .. code-block::

            import megengine as mge

            @mge.config._override(
                benchmark_kernel=True,
                deterministic_kernel=False,
                async_level=2,
                compute_mode="float32",
                conv_format="NHWC",
                auto_format_convert=True,
            )
            def train():
                ...
    """
    orig_flags = _reset_execution_config(
        benchmark_kernel,
        deterministic_kernel,
        async_level,
        compute_mode,
        conv_format,
        auto_format_convert,
    )
    try:
        yield
    finally:
        # restore the previous config values
        _reset_execution_config(*orig_flags)
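

# Usage sketch: since ``_override`` is built with ``contextlib.contextmanager``, it can
# also be used as a ``with`` block. ``train_one_step`` below is a hypothetical function,
# not part of this module:
#
#     import megengine as mge
#
#     with mge.config._override(benchmark_kernel=True, conv_format="NHWC"):
#         train_one_step()
#     # the previous config values are restored when the block exits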


def _get_actual_op_param(function_param, config_param):
    return function_param if config_param == "default" else config_param
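

# Resolution sketch for ``_get_actual_op_param`` (argument values are illustrative):
#
#     _get_actual_op_param("NCHW", "default")  # -> "NCHW": the per-op argument is kept
#     _get_actual_op_param("NCHW", "NHWC")     # -> "NHWC": the global config wins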