You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

context.py 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. The context of mindspore, used to configure the current execution environment,
  17. including execution mode, execution backend and other feature switches.
  18. """
  19. import os
  20. import threading
  21. from collections import namedtuple
  22. from types import FunctionType
  23. from mindspore import log as logger
  24. from mindspore._c_expression import MSContext
  25. from mindspore._checkparam import args_type_check
  26. from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
  27. _reset_auto_parallel_context
  # Public names of this module.
  __all__ = [
      'GRAPH_MODE',
      'PYNATIVE_MODE',
      'set_context',
      'get_context',
      'set_auto_parallel_context',
      'get_auto_parallel_context',
      'reset_auto_parallel_context',
  ]

  # Execution-mode identifiers accepted by the `mode` context attribute.
  GRAPH_MODE = 0
  PYNATIVE_MODE = 1
  def _make_directory(path):
      """
      Resolve `path` to a canonical path and create the directory if it does not exist.

      Args:
          path (str): Directory path, absolute or relative. Must be a non-blank string.

      Returns:
          str, the path returned by `os.path.realpath` for the input.

      Raises:
          ValueError: If `path` is not a non-blank string, or if the directory cannot be
              created because of missing permissions.
      """
      if not isinstance(path, str) or not path.strip():
          raise ValueError(f"Input path `{path}` is invalid type")

      # convert the relative paths
      path = os.path.realpath(path)
      logger.debug("The absolute path is %r", path)

      # A path that already exists is returned unchanged.
      if os.path.exists(path):
          return path

      logger.debug("The directory(%s) doesn't exist, will create it", path)
      try:
          # exist_ok=True tolerates the directory being created by someone else
          # between the existence check above and this call.
          os.makedirs(path, exist_ok=True)
      except PermissionError as e:
          logger.error("No write permission on the directory `%s`, error = %s", path, e)
          # Keep ValueError as the public error type, but preserve the root cause.
          raise ValueError(f"No write permission on the directory `{path}`.") from e
      return path
  class _ThreadLocalInfo(threading.local):
      """
      Thread local Info used for store thread local attributes.

      Because this class derives from `threading.local`, `__init__` runs once for every
      thread that touches an instance, so each thread starts from the defaults set below.
      """

      def __init__(self):
          super().__init__()
          self._reserve_class_name_in_scope = True
          # Read and written by `_Context.enable_debug_runtime`. It needs a default here:
          # otherwise only the thread that assigned it would have the attribute and every
          # other thread would hit AttributeError on read.
          self.debug_runtime = False

      @property
      def reserve_class_name_in_scope(self):
          """Gets whether to save the network class name in the scope."""
          return self._reserve_class_name_in_scope

      @reserve_class_name_in_scope.setter
      def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
          """
          Sets whether to save the network class name in the scope.

          Raises:
              ValueError: If the given value is not a bool.
          """
          if not isinstance(reserve_class_name_in_scope, bool):
              raise ValueError("Set reserve_class_name_in_scope value must be bool!")
          self._reserve_class_name_in_scope = reserve_class_name_in_scope
  # One entry of the context-switch stack: the target mode flag and the switch callback.
  _ContextRecord = namedtuple("_ContextRecord", ["is_pynative_mode", "switch_context_fn"])


  class _ContextSwitchInfo(threading.local):
      """
      Per-thread stack of context switch records.

      Args:
          is_pynative (bool): Whether to adopt the PyNative mode. When true, an initial
              PyNative record (without a switch function) is pushed.
      """

      def __init__(self, is_pynative):
          super(_ContextSwitchInfo, self).__init__()
          self.context_stack = []
          if not is_pynative:
              return
          self.push(True, None)

      def push(self, is_pynative, switch_context_fn):
          """
          Run the switch function (if it is a plain Python function) and record the switch.

          Args:
              is_pynative (bool): Whether context switch to PyNative mode.
              switch_context_fn (Function): A callable that executes the context switch.
                  Only objects of type `FunctionType` are invoked; anything else is stored
                  in the record without being called.
          """
          should_invoke = isinstance(switch_context_fn, FunctionType)
          if should_invoke:
              switch_context_fn()
          record = _ContextRecord(is_pynative_mode=is_pynative, switch_context_fn=switch_context_fn)
          self.context_stack.append(record)

      def pop(self):
          """Discard the most recently pushed context switch record."""
          self.context_stack.pop()
  class _Context:
      """
      _Context is the environment in which operations are executed.

      Most attributes are thin properties forwarding to the handle returned by
      `MSContext.get_instance()`. `reserve_class_name_in_scope` and `enable_debug_runtime`
      are kept per thread in `_ThreadLocalInfo`.

      Note:
          Creating a context by instantiating _Context directly is not recommended.
          Use _context() to get the context, since _Context is a singleton.
      """
      _instance = None
      _instance_lock = threading.Lock()

      def __init__(self):
          # NOTE: Python calls __init__ on every `_Context()` call, so instantiating again
          # rebuilds these members on the shared singleton instance.
          self._thread_local_info = _ThreadLocalInfo()
          self._context_switches = _ContextSwitchInfo(True)
          self._context_handle = MSContext.get_instance()

      def __new__(cls, *args, **kwargs):
          """Return the single shared instance, creating it on first use."""
          if cls._instance is None:
              # Re-check under the lock so that two threads racing through the first
              # check cannot both create an instance; `with` releases the lock even
              # if object creation raises.
              with cls._instance_lock:
                  if cls._instance is None:
                      cls._instance = object.__new__(cls)
          return cls._instance

      def __getattribute__(self, attr):
          """
          Regular attribute lookup that additionally rejects a missing context handle.

          Raises:
              ValueError: If `_context_handle` is accessed while it is None.
          """
          value = object.__getattribute__(self, attr)
          if attr == "_context_handle" and value is None:
              raise ValueError("Context handle is none in context!!!")
          return value

      # For Ascend task sink mode execution
      @property
      def enable_task_sink(self):
          """Gets the task sink flag from the context handle."""
          return self._context_handle.get_task_sink_flag()

      @enable_task_sink.setter
      def enable_task_sink(self, task_sink):
          """Sets the task sink flag on the context handle."""
          self._context_handle.set_task_sink_flag(task_sink)

      @property
      def mode(self):
          """Gets the execution mode from the context handle."""
          return self._context_handle.get_execution_mode()

      @mode.setter
      def mode(self, mode):
          """
          Switch between Graph mode and PyNative mode.

          Any value other than PYNATIVE_MODE is treated as a switch to graph mode. When the
          debug runtime is enabled, the backend policy is switched along with the mode
          ("vm" for PyNative, "ge" otherwise). The switch is recorded on the per-thread
          context switch stack.

          Args:
              mode (int): GRAPH_MODE or PYNATIVE_MODE.
          """
          self._context_handle.set_execution_mode(mode)
          if mode == PYNATIVE_MODE:
              if self.enable_debug_runtime:
                  self.set_backend_policy("vm")
              self._context_switches.push(True, None)
          else:
              if self.enable_debug_runtime:
                  self.set_backend_policy("ge")
              self._context_switches.push(False, None)

      def set_backend_policy(self, policy):
          """
          Sets the backend policy on the context handle.

          Args:
              policy (str): Backend policy name.

          Raises:
              RuntimeError: If the context handle rejects the policy.
          """
          success = self._context_handle.set_backend_policy(policy)
          if not success:
              raise RuntimeError("Backend policy must be one of ge, vm, ms.")

      @property
      def precompile_only(self):
          """Gets the precompile-only flag from the context handle."""
          return self._context_handle.get_precompile_only()

      @precompile_only.setter
      def precompile_only(self, precompile_only):
          """Sets the precompile-only flag on the context handle."""
          self._context_handle.set_precompile_only(precompile_only)

      @property
      def save_graphs(self):
          """Gets the save-graphs flag from the context handle."""
          return self._context_handle.get_save_graphs_flag()

      @save_graphs.setter
      def save_graphs(self, save_graphs_flag):
          """Sets the save-graphs flag on the context handle."""
          self._context_handle.set_save_graphs_flag(save_graphs_flag)

      @property
      def save_graphs_path(self):
          """Gets the save-graphs path from the context handle."""
          return self._context_handle.get_save_graphs_path()

      @save_graphs_path.setter
      def save_graphs_path(self, save_graphs_path):
          """Sets the save-graphs path; the directory is resolved and created via `_make_directory`."""
          self._context_handle.set_save_graphs_path(_make_directory(save_graphs_path))

      @property
      def device_target(self):
          """Gets the device target from the context handle."""
          return self._context_handle.get_device_target()

      @device_target.setter
      def device_target(self, target):
          """
          Sets the device target on the context handle.

          Raises:
              ValueError: If the context handle rejects the target name.
          """
          success = self._context_handle.set_device_target(target)
          if not success:
              raise ValueError("Target device name is invalid!!!")

      @property
      def device_id(self):
          """Gets the device id from the context handle."""
          return self._context_handle.get_device_id()

      @device_id.setter
      def device_id(self, device_id):
          """
          Sets the device id on the context handle.

          Raises:
              ValueError: If `device_id` is outside [0, 4095].
              RuntimeError: If the context handle fails to apply the id.
          """
          # The range is validated before the handle is touched.
          if device_id < 0 or device_id > 4095:
              raise ValueError("Device id must be in [0, 4095], but got {}".format(device_id))
          success = self._context_handle.set_device_id(device_id)
          if not success:
              raise RuntimeError("Device id set failed!!!")

      @property
      def enable_ir_fusion(self):
          """Gets the ir fusion flag from the context handle."""
          return self._context_handle.get_ir_fusion_flag()

      @enable_ir_fusion.setter
      def enable_ir_fusion(self, enable_ir_fusion):
          """Sets the ir fusion flag on the context handle."""
          self._context_handle.set_ir_fusion_flag(enable_ir_fusion)

      @property
      def enable_loop_sink(self):
          """Gets the loop sink flag from the context handle."""
          return self._context_handle.get_loop_sink_flag()

      @enable_loop_sink.setter
      def enable_loop_sink(self, enable_loop_sink):
          """Sets the loop sink flag on the context handle."""
          self._context_handle.set_loop_sink_flag(enable_loop_sink)

      @property
      def enable_mem_reuse(self):
          """Gets the memory reuse flag from the context handle."""
          return self._context_handle.get_enable_mem_reuse()

      @enable_mem_reuse.setter
      def enable_mem_reuse(self, enable_mem_reuse):
          """Sets the memory reuse flag on the context handle."""
          self._context_handle.set_enable_mem_reuse(enable_mem_reuse)

      @property
      def save_ms_model(self):
          """Gets the save-ms-model flag from the context handle."""
          return self._context_handle.get_save_ms_model_flag()

      @save_ms_model.setter
      def save_ms_model(self, save_ms_model_flag):
          """Sets the save-ms-model flag on the context handle."""
          self._context_handle.set_save_ms_model_flag(save_ms_model_flag)

      @property
      def save_ms_model_path(self):
          """Gets the save-ms-model path from the context handle."""
          return self._context_handle.get_save_ms_model_path()

      @save_ms_model_path.setter
      def save_ms_model_path(self, save_ms_model_path):
          """Sets the save-ms-model path on the context handle (passed through unchanged)."""
          self._context_handle.set_save_ms_model_path(save_ms_model_path)

      @property
      def enable_gpu_summary(self):
          """Gets the gpu summary flag from the context handle."""
          return self._context_handle.get_enable_gpu_summary()

      @enable_gpu_summary.setter
      def enable_gpu_summary(self, enable_gpu_summary):
          """Sets the gpu summary flag on the context handle."""
          self._context_handle.set_enable_gpu_summary(enable_gpu_summary)

      @property
      def enable_auto_mixed_precision(self):
          """Gets the auto mixed precision flag from the context handle."""
          return self._context_handle.get_auto_mixed_precision_flag()

      @enable_auto_mixed_precision.setter
      def enable_auto_mixed_precision(self, enable_auto_mixed_precision):
          """Sets the auto mixed precision flag on the context handle."""
          self._context_handle.set_auto_mixed_precision_flag(enable_auto_mixed_precision)

      @property
      def enable_reduce_precision(self):
          """Gets the reduce precision flag from the context handle."""
          return self._context_handle.get_enable_reduce_precision_flag()

      @enable_reduce_precision.setter
      def enable_reduce_precision(self, enable_reduce_precision):
          """Sets the reduce precision flag on the context handle."""
          self._context_handle.set_enable_reduce_precision_flag(enable_reduce_precision)

      @property
      def enable_dump(self):
          """Gets the dump flag from the context handle."""
          return self._context_handle.get_enable_dump()

      @enable_dump.setter
      def enable_dump(self, enable_dump):
          """Sets the dump flag on the context handle."""
          self._context_handle.set_enable_dump(enable_dump)

      @property
      def save_dump_path(self):
          """Gets the dump path from the context handle."""
          return self._context_handle.get_save_dump_path()

      @save_dump_path.setter
      def save_dump_path(self, save_dump_path):
          """Sets the dump path on the context handle (passed through unchanged)."""
          self._context_handle.set_save_dump_path(save_dump_path)

      @property
      def reserve_class_name_in_scope(self):
          """Gets whether to save the network class name in the scope."""
          return self._thread_local_info.reserve_class_name_in_scope

      @reserve_class_name_in_scope.setter
      def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
          """Sets whether to save the network class name in the scope."""
          self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope

      @property
      def enable_dynamic_memory(self):
          """Gets the dynamic memory pool flag from the context handle."""
          return self._context_handle.get_enable_dynamic_mem_pool()

      @enable_dynamic_memory.setter
      def enable_dynamic_memory(self, enable_dynamic_memory):
          """Sets the dynamic memory pool flag on the context handle."""
          self._context_handle.set_enable_dynamic_mem_pool(enable_dynamic_memory)

      @property
      def graph_memory_max_size(self):
          """Always None: the configured value is write-only from this class."""
          return None

      @graph_memory_max_size.setter
      def graph_memory_max_size(self, graph_memory_max_size):
          """
          Sets the graph memory max size.

          Args:
              graph_memory_max_size (str): Size in the form accepted by
                  `check_input_format`, such as "26GB".

          Raises:
              ValueError: If the value is not in the expected format.
          """
          if check_input_format(graph_memory_max_size):
              # Strip the trailing "GB" and hand the handle a textual byte expression.
              graph_memory_max_size_ = graph_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
              self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
          else:
              raise ValueError("Context param graph_memory_max_size should be in correct format! Such as \"26GB\"")

      @property
      def variable_memory_max_size(self):
          """Always None: the configured value is write-only from this class."""
          return None

      @variable_memory_max_size.setter
      def variable_memory_max_size(self, variable_memory_max_size):
          """
          Sets the variable memory max size.

          Args:
              variable_memory_max_size (str): Size in the form accepted by
                  `check_input_format`, such as "5GB".

          Raises:
              ValueError: If the value is not in the expected format.
          """
          if check_input_format(variable_memory_max_size):
              # Strip the trailing "GB" and hand the handle a textual byte expression.
              variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
              self._context_handle.set_variable_memory_max_size(variable_memory_max_size_)
          else:
              raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")

      @property
      def enable_ge(self):
          """Whether the context handle's backend policy is 'ge'."""
          return self._context_handle.get_backend_policy() == 'ge'

      @property
      def enable_debug_runtime(self):
          """Gets the per-thread debug runtime flag; False on threads where it was never set."""
          # The flag lives in thread-local storage, so a thread that never assigned it
          # has no such attribute; fall back to False instead of raising AttributeError.
          return getattr(self._thread_local_info, "debug_runtime", False)

      @enable_debug_runtime.setter
      def enable_debug_runtime(self, enable):
          """Sets the debug runtime flag for the current thread."""
          thread_info = self._thread_local_info
          thread_info.debug_runtime = enable
  def check_input_format(x):
      """
      Check whether `x` is a memory size string such as "26GB" or "0.5GB".

      The whole string must match: a number starting with a non-zero digit (optionally
      with a decimal point and further digits) or a number of the form "0.<digits>",
      immediately followed by "GB".

      Args:
          x (str): The string to check.

      Returns:
          bool, True if `x` has the expected format, False otherwise.
      """
      import re
      pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
      # fullmatch rather than match: callers strip the last two characters assuming they
      # are exactly "GB", so trailing text after the unit must be rejected.
      return re.fullmatch(pattern, x) is not None
  # Lazily created module-wide context instance; see _context().
  _k_context = None


  def _context():
      """
      Get the global _context, if context is not created, create a new one.

      On first use the backend is taken from `mindspore.default_config.__backend__`,
      falling back to 'debug' when that module cannot be imported. A 'debug' backend
      enables the debug runtime and selects the 'vm' backend policy.

      Returns:
          _Context, the global context in PyNative mode.
      """
      global _k_context
      if _k_context is not None:
          return _k_context

      try:
          from mindspore import default_config
          backend_name = default_config.__backend__
      except ImportError:
          logger.error("import default config fail")
          backend_name = 'debug'

      _k_context = _Context()
      _k_context.enable_debug_runtime = False
      if backend_name == 'debug':
          _k_context.enable_debug_runtime = True
          backend_name = 'vm'
      _k_context.set_backend_policy(backend_name)
      return _k_context
  @args_type_check(
      device_num=int,
      global_rank=int,
      mirror_mean=bool,
      cast_before_mirror=bool,
      parallel_mode=str,
      parameter_broadcast=bool,
  )
  def set_auto_parallel_context(**kwargs):
      """
      Set auto parallel context.

      All keyword arguments are forwarded unchanged to `_set_auto_parallel_context`.

      Note:
          Attribute name is required for setting attributes.

      Args:
          device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
          global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
          mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror.
              "stand_alone" do not support mirror_mean. Default: False.
          cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True.
              "stand_alone", "data_parallel" and "hybrid_parallel" do not support
              cast_before_mirror. Default: True.
          parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
              "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".

              - stand_alone: Only one processor working.

              - data_parallel: Distributing the data across different processors.

              - hybrid_parallel: Achieving data parallelism and model parallelism manually.

              - semi_auto_parallel: Achieving data parallelism and model parallelism by
                setting parallel strategies.

              - auto_parallel: Achieving parallelism automatically.
          parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
              "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
              broadcast. Default: False.

      Raises:
          ValueError: If input key is not attribute in auto parallel context.

      Examples:
          >>> context.set_auto_parallel_context(device_num=8)
          >>> context.set_auto_parallel_context(global_rank=0)
          >>> context.set_auto_parallel_context(mirror_mean=True)
          >>> context.set_auto_parallel_context(cast_before_mirror=False)
          >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
          >>> context.set_auto_parallel_context(parameter_broadcast=False)
      """
      _set_auto_parallel_context(**kwargs)
  def get_auto_parallel_context(attr_key):
      """
      Gets auto parallel context attribute value according to the key.

      The lookup is delegated to `_get_auto_parallel_context`.

      Args:
          attr_key (str): The key of the attribute.

      Returns:
          Returns attribute value according to the key.

      Raises:
          ValueError: If input key is not attribute in auto parallel context.
      """
      attr_value = _get_auto_parallel_context(attr_key)
      return attr_value
  def reset_auto_parallel_context():
      """
      Reset auto parallel context attributes to the default values.

      The reset is delegated to `_reset_auto_parallel_context`. Documented defaults:

      - device_num: 1.
      - global_rank: 0.
      - mirror_mean: False.
      - cast_before_mirror: True.
      - parallel_mode: "stand_alone".
      - parameter_broadcast: False.
      """
      _reset_auto_parallel_context()
  @args_type_check(
      mode=int, precompile_only=bool, device_target=str,
      device_id=int, enable_ir_fusion=bool, save_graphs=bool,
      enable_task_sink=bool, save_graphs_path=str, enable_loop_sink=bool,
      enable_mem_reuse=bool, save_ms_model=bool, save_ms_model_path=str,
      enable_gpu_summary=bool, enable_auto_mixed_precision=bool,
      enable_dump=bool, save_dump_path=str, enable_reduce_precision=bool,
      enable_dynamic_memory=bool, graph_memory_max_size=str,
      variable_memory_max_size=str,
  )
  def set_context(**kwargs):
      """
      Sets context for running environment.

      Context should be configured before running your program. If there is no configuration,
      the "Ascend" device target will be used by default. GRAPH_MODE or
      PYNATIVE_MODE can be set by `mode` attribute and both modes support all backends, default
      mode is PYNATIVE_MODE.

      When the `save_graphs` attribute is set to True, attribute of `save_graphs_path` is used to set the
      intermediate compilation graph storage path. By default, the graphs are saved in the current directory.
      As for other configurations and arguments, please refer to the corresponding module
      description, the configuration is optional and can be enabled when needed.

      Attributes are applied one at a time in keyword order, so attributes that precede an
      unrecognized keyword have already been set when the error is raised.

      Note:
          Attribute name is required for setting attributes.
          If need to config graph max memory size and variable max memory size, one must make sure:
          The sum of graph_memory_max_size and variable_memory_max_size should be less than total memory size of
          a device, while the total memory is supposed to be no more than 256GB.

      Args:
          mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). Default: PYNATIVE_MODE.
          device_target (str): The target device to run, support "Ascend", "GPU", "CPU". Default: "Ascend".
          device_id (int): Id of target device, the value must be in [0, device_num_per_host-1],
              while device_num_per_host should no more than 4096. Default: 0.
          enable_ir_fusion (bool): Whether to enable ir fusion. Default: True.
          save_graphs (bool): Whether to save graphs. Default: False.
          enable_loop_sink (bool): Whether to enable loop sink. Default: True.
          enable_task_sink (bool): Whether to enable task sink. Default: True.
          enable_mem_reuse (bool): Whether to enable memory reuse. Default: True.
          save_ms_model (bool): Whether to save lite model converted by graph. Default: False.
          save_ms_model_path (str): Path to save converted lite model. Default: "."
          enable_gpu_summary (bool): Whether to enable gpu summary. Default: True.
          save_graphs_path (str): Path to save graphs. Default: "."
          enable_auto_mixed_precision (bool): Whether to enable auto mixed precision. Default: True.
          reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True.
          enable_reduce_precision (bool): Whether to enable precision reduction. Default: True.
          enable_dump (bool): Whether to enable dump. Default: False.
          save_dump_path (str): When the program is executed on Ascend, operators can dump data here.
              The root dump path is configured in /home/HwHiAiUser/ide_daemon/ide_daemon.cfg.
              So the real dump path is "{configured root dump path}/{`save_dump_path`}". Default: ".".
          enable_dynamic_memory (bool): Whether to enable dynamic memory. Default: False.
          graph_memory_max_size (str): Sets graph memory max size. Default: "26GB".
          variable_memory_max_size (str): Sets variable memory max size. Default: "5GB".

      Raises:
          ValueError: If input key is not an attribute in context.

      Examples:
          >>> context.set_context(mode=context.GRAPH_MODE)
          >>> context.set_context(mode=context.PYNATIVE_MODE)
          >>> context.set_context(device_target="Ascend")
          >>> context.set_context(device_id=0)
          >>> context.set_context(save_graphs=True, save_graphs_path="./model.ms")
          >>> context.set_context(enable_task_sink=True)
          >>> context.set_context(enable_mem_reuse=True)
          >>> context.set_context(enable_reduce_precision=True)
          >>> context.set_context(save_ms_model=True, save_ms_model_path=".")
          >>> context.set_context(enable_gpu_summary=False)
          >>> context.set_context(enable_dump=False, save_dump_path=".")
          >>> context.set_context(reserve_class_name_in_scope=True)
          >>> context.set_context(enable_dynamic_memory=True)
          >>> context.set_context(graph_memory_max_size="25GB")
          >>> context.set_context(variable_memory_max_size="6GB")
          >>> context.set_context(mode=context.GRAPH_MODE,
          ...                     device_target="Ascend", device_id=0, save_graphs=True,
          ...                     save_graphs_path="/mindspore")
      """
      for attr_name, attr_value in kwargs.items():
          # Looked up inside the loop so that an empty call never creates the context.
          ctx = _context()
          if hasattr(ctx, attr_name):
              setattr(ctx, attr_name, attr_value)
              continue
          raise ValueError("Set context keyword %s is not recognized!" % attr_name)
  442. """
  443. for key, value in kwargs.items():
  444. if not hasattr(_context(), key):
  445. raise ValueError("Set context keyword %s is not recognized!" % key)
  446. setattr(_context(), key, value)
  447. def get_context(attr_key):
  448. """
  449. Gets context attribute value according to the input key.
  450. Args:
  451. attr_key (str): The key of the attribute.
  452. Returns:
  453. Object, The value of given attribute key.
  454. Raises:
  455. ValueError: If input key is not an attribute in context.
  456. """
  457. if not hasattr(_context(), attr_key):
  458. raise ValueError("Get context keyword %s is not recognized!" % attr_key)
  459. return getattr(_context(), attr_key)