You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

context.py 27 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. The context of mindspore, used to configure the current execution environment,
  17. including execution mode, execution backend and other feature switches.
  18. """
  19. import os
  20. import time
  21. import threading
  22. from collections import namedtuple
  23. from types import FunctionType
  24. from mindspore import log as logger
  25. from mindspore._c_expression import MSContext
  26. from mindspore._checkparam import args_type_check
  27. from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
  28. _reset_auto_parallel_context
  29. __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
  30. 'get_auto_parallel_context', 'reset_auto_parallel_context']
  31. GRAPH_MODE = 0
  32. PYNATIVE_MODE = 1
  33. # The max memory size of graph plus variable.
  34. _DEVICE_APP_MEMORY_SIZE = 31
  35. def _make_directory(path):
  36. """Make directory."""
  37. real_path = None
  38. if path is None or not isinstance(path, str) or path.strip() == "":
  39. raise ValueError(f"Input path `{path}` is invalid type")
  40. # convert the relative paths
  41. path = os.path.realpath(path)
  42. logger.debug("The absolute path is %r", path)
  43. # check whether the path is already existed and has written permissions
  44. if os.path.exists(path):
  45. real_path = path
  46. else:
  47. # All exceptions need to be caught because create directory maybe have some limit(permissions)
  48. logger.debug("The directory(%s) doesn't exist, will create it", path)
  49. try:
  50. os.makedirs(path)
  51. real_path = path
  52. except PermissionError as e:
  53. logger.error(f"No write permission on the directory `{path}, error = {e}")
  54. raise ValueError(f"No write permission on the directory `{path}`.")
  55. return real_path
  56. def _get_print_file_name(file_name):
  57. """Add timestamp suffix to file name. Rename the file name: file_name + "." + time(seconds)."""
  58. time_second = str(int(time.time()))
  59. file_name = file_name + "." + time_second
  60. if os.path.exists(file_name):
  61. ValueError("This file {} already exists.".format(file_name))
  62. return file_name
  63. class _ThreadLocalInfo(threading.local):
  64. """
  65. Thread local Info used for store thread local attributes.
  66. """
  67. def __init__(self):
  68. super(_ThreadLocalInfo, self).__init__()
  69. self._reserve_class_name_in_scope = True
  70. @property
  71. def reserve_class_name_in_scope(self):
  72. """Gets whether to save the network class name in the scope."""
  73. return self._reserve_class_name_in_scope
  74. @reserve_class_name_in_scope.setter
  75. def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
  76. """Sets whether to save the network class name in the scope."""
  77. if not isinstance(reserve_class_name_in_scope, bool):
  78. raise ValueError(
  79. "Set reserve_class_name_in_scope value must be bool!")
  80. self._reserve_class_name_in_scope = reserve_class_name_in_scope
  81. _ContextRecord = namedtuple(
  82. "_ContextRecord", ["is_pynative_mode", "switch_context_fn"])
  83. class _ContextSwitchInfo(threading.local):
  84. """
  85. Record of context switch information.
  86. Args:
  87. is_pynative (bool): Whether to adopt the PyNative mode.
  88. """
  89. def __init__(self, is_pynative):
  90. super(_ContextSwitchInfo, self).__init__()
  91. self.context_stack = []
  92. if is_pynative:
  93. self.push(True, None)
  94. def push(self, is_pynative, switch_context_fn):
  95. """
  96. Push a context switch record onto the stack.
  97. Args:
  98. is_pynative (bool): Whether context switch to PyNative mode.
  99. switch_context_fn (Function): A callable that executes the context switch.
  100. """
  101. if isinstance(switch_context_fn, FunctionType):
  102. switch_context_fn()
  103. self.context_stack.append(
  104. _ContextRecord(is_pynative, switch_context_fn))
  105. def pop(self):
  106. self.context_stack.pop()
  107. class _Context:
  108. """
  109. _Context is the environment in which operations are executed
  110. Note:
  111. Create a context through instantiating Context object is not recommended.
  112. should use context() to get the context since Context is singleton.
  113. """
  114. _instance = None
  115. _instance_lock = threading.Lock()
  116. def __init__(self):
  117. self._thread_local_info = _ThreadLocalInfo()
  118. self._context_switches = _ContextSwitchInfo(True)
  119. self._context_handle = MSContext.get_instance()
  120. def __new__(cls, *args, **kwargs):
  121. if cls._instance is None:
  122. cls._instance_lock.acquire()
  123. cls._instance = object.__new__(cls)
  124. cls._instance_lock.release()
  125. return cls._instance
  126. def __getattribute__(self, attr):
  127. value = object.__getattribute__(self, attr)
  128. if attr == "_context_handle" and value is None:
  129. raise ValueError("Context handle is none in context!!!")
  130. return value
  131. @property
  132. def mode(self):
  133. return self._context_handle.get_execution_mode()
  134. @mode.setter
  135. def mode(self, mode):
  136. """
  137. Switch between Graph mode and PyNative mode.
  138. Args:
  139. mode (int): GRAPH_MODE or PYNATIVE_MODE.
  140. """
  141. self._context_handle.set_execution_mode(mode)
  142. if mode == PYNATIVE_MODE:
  143. if self.enable_debug_runtime:
  144. self.set_backend_policy("vm")
  145. self._context_switches.push(True, None)
  146. else:
  147. if self.enable_debug_runtime:
  148. self.set_backend_policy("ge")
  149. self._context_switches.push(False, None)
  150. def set_backend_policy(self, policy):
  151. success = self._context_handle.set_backend_policy(policy)
  152. if not success:
  153. raise RuntimeError("Backend policy must be one of ge, vm, ms.")
  154. @property
  155. def precompile_only(self):
  156. return self._context_handle.get_precompile_only()
  157. @precompile_only.setter
  158. def precompile_only(self, precompile_only):
  159. self._context_handle.set_precompile_only(precompile_only)
  160. @property
  161. def save_graphs(self):
  162. return self._context_handle.get_save_graphs_flag()
  163. @save_graphs.setter
  164. def save_graphs(self, save_graphs_flag):
  165. self._context_handle.set_save_graphs_flag(save_graphs_flag)
  166. @property
  167. def save_graphs_path(self):
  168. return self._context_handle.get_save_graphs_path()
  169. @save_graphs_path.setter
  170. def save_graphs_path(self, save_graphs_path):
  171. self._context_handle.set_save_graphs_path(
  172. _make_directory(save_graphs_path))
  173. @property
  174. def device_target(self):
  175. return self._context_handle.get_device_target()
  176. @device_target.setter
  177. def device_target(self, target):
  178. success = self._context_handle.set_device_target(target)
  179. if not success:
  180. raise ValueError("Target device name is invalid!!!")
  181. if self.enable_debug_runtime and self.device_target == "CPU":
  182. self.set_backend_policy("vm")
  183. @property
  184. def device_id(self):
  185. return self._context_handle.get_device_id()
  186. @device_id.setter
  187. def device_id(self, device_id):
  188. if device_id < 0 or device_id > 4095:
  189. raise ValueError(
  190. "Device id must be in [0, 4095], but got {}".format(device_id))
  191. success = self._context_handle.set_device_id(device_id)
  192. if not success:
  193. raise RuntimeError("Device id set failed!!!")
  194. @property
  195. def max_call_depth(self):
  196. return self._context_handle.get_max_call_depth()
  197. @max_call_depth.setter
  198. def max_call_depth(self, max_call_depth):
  199. if max_call_depth <= 0:
  200. raise ValueError(
  201. "Max call depth must be greater than 0, but got {}".format(max_call_depth))
  202. self._context_handle.set_max_call_depth(max_call_depth)
  203. @property
  204. def enable_auto_mixed_precision(self):
  205. return self._context_handle.get_auto_mixed_precision_flag()
  206. @enable_auto_mixed_precision.setter
  207. def enable_auto_mixed_precision(self, enable_auto_mixed_precision):
  208. self._context_handle.set_auto_mixed_precision_flag(
  209. enable_auto_mixed_precision)
  210. @property
  211. def enable_reduce_precision(self):
  212. return self._context_handle.get_enable_reduce_precision_flag()
  213. @enable_reduce_precision.setter
  214. def enable_reduce_precision(self, enable_reduce_precision):
  215. self._context_handle.set_enable_reduce_precision_flag(
  216. enable_reduce_precision)
  217. @property
  218. def enable_dump(self):
  219. return self._context_handle.get_enable_dump()
  220. @enable_dump.setter
  221. def enable_dump(self, enable_dump):
  222. self._context_handle.set_enable_dump(enable_dump)
  223. @property
  224. def save_dump_path(self):
  225. return self._context_handle.get_save_dump_path()
  226. @save_dump_path.setter
  227. def save_dump_path(self, save_dump_path):
  228. self._context_handle.set_save_dump_path(save_dump_path)
  229. @property
  230. def enable_profiling(self):
  231. return self._context_handle.get_enable_profiling()
  232. @enable_profiling.setter
  233. def enable_profiling(self, flag):
  234. self._context_handle.set_enable_profiling(flag)
  235. @property
  236. def profiling_options(self):
  237. return self._context_handle.get_profiling_options()
  238. @profiling_options.setter
  239. def profiling_options(self, option):
  240. options = ["training_trace", "task_trace",
  241. "task_trace:training_trace", "training_trace:task_trace", "op_trace"]
  242. if option not in options:
  243. raise ValueError("Profiling options must be in 'training_trace' 'task_trace' "
  244. "'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.")
  245. self._context_handle.set_profiling_options(option)
  246. @property
  247. def enable_graph_kernel(self):
  248. return self._context_handle.get_enable_graph_kernel()
  249. @enable_graph_kernel.setter
  250. def enable_graph_kernel(self, graph_kernel_switch_):
  251. self._context_handle.set_enable_graph_kernel(graph_kernel_switch_)
  252. @property
  253. def reserve_class_name_in_scope(self):
  254. """Gets whether to save the network class name in the scope."""
  255. return self._thread_local_info.reserve_class_name_in_scope
  256. @reserve_class_name_in_scope.setter
  257. def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
  258. """Sets whether to save the network class name in the scope."""
  259. self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope
  260. @property
  261. def variable_memory_max_size(self):
  262. return None
  263. @variable_memory_max_size.setter
  264. def variable_memory_max_size(self, variable_memory_max_size):
  265. if not check_input_format(variable_memory_max_size):
  266. raise ValueError(
  267. "Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
  268. if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
  269. raise ValueError(
  270. "Context param variable_memory_max_size should be less than 31GB.")
  271. variable_memory_max_size_ = variable_memory_max_size[:-
  272. 2] + " * 1024 * 1024 * 1024"
  273. graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - \
  274. int(variable_memory_max_size[:-2])
  275. graph_memory_max_size_ = str(
  276. graph_memory_max_size) + " * 1024 * 1024 * 1024"
  277. self._context_handle.set_variable_memory_max_size(
  278. variable_memory_max_size_)
  279. self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
  280. @property
  281. def enable_ge(self):
  282. return self._context_handle.get_backend_policy() == 'ge'
  283. @property
  284. def enable_debug_runtime(self):
  285. return self._thread_local_info.debug_runtime
  286. @enable_debug_runtime.setter
  287. def enable_debug_runtime(self, enable):
  288. thread_info = self._thread_local_info
  289. thread_info.debug_runtime = enable
  290. @property
  291. def check_bprop(self):
  292. return self._context_handle.get_check_bprop_flag()
  293. @check_bprop.setter
  294. def check_bprop(self, check_bprop_flag):
  295. self._context_handle.set_check_bprop_flag(check_bprop_flag)
  296. @property
  297. def max_device_memory(self):
  298. return self._context_handle.get_max_device_memory()
  299. @max_device_memory.setter
  300. def max_device_memory(self, max_device_memory):
  301. if not check_input_format(max_device_memory):
  302. raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
  303. max_device_memory_value = float(max_device_memory[:-2])
  304. if max_device_memory_value == 0:
  305. raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
  306. self._context_handle.set_max_device_memory(max_device_memory_value)
  307. @property
  308. def print_file_path(self):
  309. return None
  310. @print_file_path.setter
  311. def print_file_path(self, file_path):
  312. """Add timestamp suffix to file name. Sets print file path."""
  313. print_file_path = os.path.realpath(file_path)
  314. if os.path.isdir(print_file_path):
  315. raise IOError("Print_file_path should be file path, but got {}.".format(file_path))
  316. if os.path.exists(print_file_path):
  317. _path, _file_name = os.path.split(print_file_path)
  318. path = _make_directory(_path)
  319. file_name = _get_print_file_name(_file_name)
  320. full_file_name = os.path.join(path, file_name)
  321. else:
  322. full_file_name = print_file_path
  323. self._context_handle.set_print_file_path(full_file_name)
  324. @property
  325. def enable_sparse(self):
  326. return self._context_handle.get_enable_sparse()
  327. @enable_sparse.setter
  328. def enable_sparse(self, enable_sparse):
  329. self._context_handle.set_enable_sparse(enable_sparse)
  330. def check_input_format(x):
  331. import re
  332. pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
  333. result = re.match(pattern, x)
  334. return result is not None
  335. _k_context = None
  336. def _context():
  337. """
  338. Get the global _context, if context is not created, create a new one.
  339. Returns:
  340. _Context, the global context in PyNative mode.
  341. """
  342. global _k_context
  343. if _k_context is None:
  344. default_backend = 'debug'
  345. try:
  346. from mindspore import default_config
  347. default_backend = default_config.__backend__
  348. except ImportError:
  349. logger.error("import default config fail")
  350. _k_context = _Context()
  351. _k_context.enable_debug_runtime = False
  352. if default_backend == 'debug':
  353. _k_context.enable_debug_runtime = True
  354. default_backend = 'vm'
  355. _k_context.set_backend_policy(default_backend)
  356. return _k_context
  357. @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
  358. auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
  359. strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
  360. def set_auto_parallel_context(**kwargs):
  361. """
  362. Set auto parallel context.
  363. Note:
  364. Attribute name is required for setting attributes.
  365. If a program has tasks with different parallel modes, then before setting new parallel mode for
  366. next task, interface mindspore.context.reset_auto_parallel_context() needs to be called to reset
  367. the configuration.
  368. Setting or changing parallel modes must be called before any Initializer created, or RuntimeError
  369. may be raised when compile network.
  370. Args:
  371. device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
  372. global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
  373. mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror.
  374. "stand_alone" do not support mirror_mean. Default: False.
  375. cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True.
  376. "stand_alone", "data_parallel" and "hybrid_parallel" do not support
  377. cast_before_mirror. Default: True.
  378. parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
  379. "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
  380. - stand_alone: Only one processor working.
  381. - data_parallel: Distributing the data across different processors.
  382. - hybrid_parallel: Achieving data parallelism and model parallelism manually.
  383. - semi_auto_parallel: Achieving data parallelism and model parallelism by
  384. setting parallel strategies.
  385. - auto_parallel: Achieving parallelism automatically.
  386. auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
  387. and "dynamic_programming". Default: "dynamic_programming".
  388. - recursive_programming: Recursive programming search mode.
  389. - dynamic_programming: Dynamic programming search mode.
  390. parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
  391. "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
  392. broadcast. Default: False.
  393. strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
  394. strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
  395. full_batch (bool): Whether to load the whole batch on each device. Default: False.
  396. enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in
  397. data parallel training in the benefit of time and memory saving.
  398. max_call_depth(int): Specify the function call depth limit. Default: 1000.
  399. Raises:
  400. ValueError: If input key is not attribute in auto parallel context.
  401. Examples:
  402. >>> context.set_auto_parallel_context(device_num=8)
  403. >>> context.set_auto_parallel_context(global_rank=0)
  404. >>> context.set_auto_parallel_context(mirror_mean=True)
  405. >>> context.set_auto_parallel_context(cast_before_mirror=False)
  406. >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
  407. >>> context.set_auto_parallel_context(parameter_broadcast=False)
  408. >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
  409. >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
  410. >>> context.set_auto_parallel_context(max_call_depth=80)
  411. """
  412. _set_auto_parallel_context(**kwargs)
  413. def get_auto_parallel_context(attr_key):
  414. """
  415. Gets auto parallel context attribute value according to the key.
  416. Args:
  417. attr_key (str): The key of the attribute.
  418. Returns:
  419. Returns attribute value according to the key.
  420. Raises:
  421. ValueError: If input key is not attribute in auto parallel context.
  422. """
  423. return _get_auto_parallel_context(attr_key)
  424. def reset_auto_parallel_context():
  425. """
  426. Reset auto parallel context attributes to the default values:
  427. - device_num: 1.
  428. - global_rank: 0.
  429. - mirror_mean: False.
  430. - cast_before_mirror: True.
  431. - parallel_mode: "stand_alone".
  432. - parameter_broadcast: False.
  433. - strategy_ckpt_load_file: "".
  434. - strategy_ckpt_save_file: "".
  435. - enable_parallel_optimizer: False.
  436. """
  437. _reset_auto_parallel_context()
  438. @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
  439. save_graphs_path=str, enable_dump=bool,
  440. save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
  441. enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
  442. enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
  443. enable_sparse=bool, max_call_depth=int)
  444. def set_context(**kwargs):
  445. """
  446. Sets context for running environment.
  447. Context should be configured before running your program. If there is no configuration,
  448. the "Ascend" device target will be used by default. GRAPH_MODE or
  449. PYNATIVE_MODE can be set by `mode` attribute and both modes support all backends, default
  450. mode is PYNATIVE_MODE.
  451. When the `save_graphs` attribute is set to True, attribute of `save_graphs_path` is used to set the
  452. intermediate compilation graph storage path. By default, the graphs are saved in the current directory.
  453. As for other configurations and arguments, please refer to the corresponding module
  454. description, the configuration is optional and can be enabled when needed.
  455. Note:
  456. Attribute name is required for setting attributes.
  457. The mode is not recommended to be changed after net was initilized because the implementations of some
  458. operations are different in graph mode and pynative mode. Default: PYNATIVE_MODE.
  459. Args:
  460. mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1).
  461. device_target (str): The target device to run, support "Ascend", "GPU", "CPU". Default: "Ascend".
  462. device_id (int): Id of target device, the value must be in [0, device_num_per_host-1],
  463. while device_num_per_host should no more than 4096. Default: 0.
  464. save_graphs (bool): Whether to save graphs. Default: False.
  465. save_graphs_path (str): Path to save graphs. Default: "."
  466. enable_auto_mixed_precision (bool): Whether to enable auto mixed precision. Default: True.
  467. enable_graph_kernel (bool): Whether to enable composition of basic primitives. These primitives would be
  468. compiled into a fused kernel automatically. Default: False.
  469. reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True.
  470. enable_reduce_precision (bool): Whether to enable precision reduction. Default: True.
  471. enable_dump (bool): Whether to enable dump. Default: False.
  472. save_dump_path (str): When the program is executed on Ascend, operators can dump data here.
  473. The root dump path is configured in /home/HwHiAiUser/ide_daemon/ide_daemon.cfg.
  474. So the real dump path is "{configured root dump path}/{`save_dump_path`}". Default: ".".
  475. variable_memory_max_size (str): Sets variable memory max size. Default: "5GB".
  476. enable_profiling (bool): Whether to open profiling. Default: False.
  477. profiling_options (str): Sets profiling collection options, operators can profiling data here.
  478. Profiling collection options, the values are as follows, supporting the collection of multiple data.
  479. - training_trace: collect iterative trajectory data, that is, the training task and software information of
  480. the AI software stack, to achieve performance analysis of the training task, focusing on data
  481. enhancement, forward and backward calculation, gradient aggregation update and other related data.
  482. - task_trace: collect task trajectory data, that is, the hardware information of the HWTS/AICore of
  483. the Ascend 910 processor, and analyze the information of start and end of the task.
  484. - op_trace: collect single operator performance data.
  485. The profiling can choose training_trace, task_trace, training_trace and task_trace combination and
  486. separated by colons; single operator can choose op_trace, op_trace cannot be combined with
  487. training_trace and task_trace. Default: "training_trace".
  488. check_bprop (bool): Whether to check bprop. Default: False.
  489. max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU.
  490. The format is "xxGB". Default: "1024GB".
  491. print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
  492. a file by default, and turn off printing to the screen. If the file already exists, add a timestamp
  493. suffix to the file.
  494. enable_sparse (bool): Whether to enable sparsity feature. Default: False.
  495. Raises:
  496. ValueError: If input key is not an attribute in context.
  497. Examples:
  498. >>> context.set_context(mode=context.GRAPH_MODE)
  499. >>> context.set_context(mode=context.PYNATIVE_MODE)
  500. >>> context.set_context(device_target="Ascend")
  501. >>> context.set_context(device_id=0)
  502. >>> context.set_context(save_graphs=True, save_graphs_path="./model.ms")
  503. >>> context.set_context(enable_reduce_precision=True)
  504. >>> context.set_context(enable_dump=True, save_dump_path=".")
  505. >>> context.set_context(reserve_class_name_in_scope=True)
  506. >>> context.set_context(variable_memory_max_size="6GB")
  507. >>> context.set_context(mode=context.GRAPH_MODE,
  508. >>> device_target="Ascend",device_id=0, save_graphs=True,
  509. >>> save_graphs_path="/mindspore")
  510. >>> context.set_context(enable_profiling=True, profiling_options="training_trace")
  511. >>> context.set_context(max_device_memory="3.5GB")
  512. >>> context.set_context(print_file_path="print.pb")
  513. """
  514. for key, value in kwargs.items():
  515. if not hasattr(_context(), key):
  516. raise ValueError("Set context keyword %s is not recognized!" % key)
  517. setattr(_context(), key, value)
  518. def get_context(attr_key):
  519. """
  520. Gets context attribute value according to the input key.
  521. Args:
  522. attr_key (str): The key of the attribute.
  523. Returns:
  524. Object, The value of given attribute key.
  525. Raises:
  526. ValueError: If input key is not an attribute in context.
  527. """
  528. if not hasattr(_context(), attr_key):
  529. raise ValueError(
  530. "Get context keyword %s is not recognized!" % attr_key)
  531. return getattr(_context(), attr_key)