You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

context.py 32 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. The context of mindspore, used to configure the current execution environment,
  17. includes the execution mode, execution backend and other feature switches.
  18. """
  19. import os
  20. import time
  21. import threading
  22. from collections import namedtuple
  23. from types import FunctionType
  24. from mindspore import log as logger
  25. from mindspore._c_expression import MSContext, ms_ctx_param
  26. from mindspore._checkparam import args_type_check, Validator
  27. from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
  28. _reset_auto_parallel_context
  29. from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
  30. from .default_config import __device_target__, __package_name__
  31. __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
  32. 'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode', 'set_ps_context',
  33. 'get_ps_context', 'reset_ps_context']
  34. GRAPH_MODE = 0
  35. PYNATIVE_MODE = 1
  36. _DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable.
  37. _re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
  38. _k_context = None
  39. def _make_directory(path):
  40. """Make directory."""
  41. real_path = None
  42. if path is None or not isinstance(path, str) or path.strip() == "":
  43. raise ValueError(f"Input path `{path}` is invalid type")
  44. # convert the relative paths
  45. path = os.path.realpath(path)
  46. logger.debug("The absolute path is %r", path)
  47. # check whether the path is already existed and has written permissions
  48. if os.path.exists(path):
  49. real_path = path
  50. else:
  51. # All exceptions need to be caught because create directory maybe have some limit(permissions)
  52. logger.debug("The directory(%s) doesn't exist, will create it", path)
  53. try:
  54. os.makedirs(path)
  55. real_path = path
  56. except PermissionError as e:
  57. logger.error(f"No write permission on the directory `{path}, error = {e}")
  58. raise ValueError(f"No write permission on the directory `{path}`.")
  59. return real_path
  60. def _get_print_file_name(file_name):
  61. """Add timestamp suffix to file name. Rename the file name: file_name + "." + time(seconds)."""
  62. time_second = str(int(time.time()))
  63. file_name = file_name + "." + time_second
  64. if os.path.exists(file_name):
  65. ValueError("This file {} already exists.".format(file_name))
  66. return file_name
  67. class _ThreadLocalInfo(threading.local):
  68. """
  69. Thread local Info used for store thread local attributes.
  70. """
  71. def __init__(self):
  72. super(_ThreadLocalInfo, self).__init__()
  73. self._reserve_class_name_in_scope = True
  74. @property
  75. def reserve_class_name_in_scope(self):
  76. """Gets whether to save the network class name in the scope."""
  77. return self._reserve_class_name_in_scope
  78. @reserve_class_name_in_scope.setter
  79. def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
  80. """Sets whether to save the network class name in the scope."""
  81. if not isinstance(reserve_class_name_in_scope, bool):
  82. raise ValueError(
  83. "Set reserve_class_name_in_scope value must be bool!")
  84. self._reserve_class_name_in_scope = reserve_class_name_in_scope
  85. _ContextRecord = namedtuple(
  86. "_ContextRecord", ["is_pynative_mode", "switch_context_fn"])
  87. class _ContextSwitchInfo(threading.local):
  88. """
  89. Record of context switch information.
  90. Args:
  91. is_pynative (bool): Whether to adopt the PyNative mode.
  92. """
  93. def __init__(self, is_pynative):
  94. super(_ContextSwitchInfo, self).__init__()
  95. self.context_stack = []
  96. if is_pynative:
  97. self.push(True, None)
  98. def push(self, is_pynative, switch_context_fn):
  99. """
  100. Push a context switch record onto the stack.
  101. Args:
  102. is_pynative (bool): Whether context switch to PyNative mode.
  103. switch_context_fn (Function): A callable that executes the context switch.
  104. """
  105. if isinstance(switch_context_fn, FunctionType):
  106. switch_context_fn()
  107. self.context_stack.append(
  108. _ContextRecord(is_pynative, switch_context_fn))
  109. def pop(self):
  110. self.context_stack.pop()
  111. class _Context:
  112. """
  113. _Context is the environment in which operations are executed
  114. Note:
  115. Create a context through instantiating Context object is not recommended.
  116. should use context() to get the context since Context is singleton.
  117. """
  118. _instance = None
  119. _instance_lock = threading.Lock()
  120. def __init__(self):
  121. self._thread_local_info = _ThreadLocalInfo()
  122. self._context_switches = _ContextSwitchInfo(True)
  123. self._context_handle = MSContext.get_instance()
  124. def __new__(cls, *args, **kwargs):
  125. if cls._instance is None:
  126. cls._instance_lock.acquire()
  127. cls._instance = object.__new__(cls)
  128. cls._instance_lock.release()
  129. return cls._instance
  130. def __getattribute__(self, attr):
  131. value = object.__getattribute__(self, attr)
  132. if attr == "_context_handle" and value is None:
  133. raise ValueError("Context handle is none in context!!!")
  134. return value
  135. def get_param(self, param):
  136. return self._context_handle.get_param(param)
  137. def set_param(self, param, value):
  138. self._context_handle.set_param(param, value)
  139. def set_mode(self, mode):
  140. """
  141. Switch between Graph mode and PyNative mode.
  142. Args:
  143. mode (int): GRAPH_MODE or PYNATIVE_MODE.
  144. """
  145. if mode == PYNATIVE_MODE:
  146. if self.enable_debug_runtime:
  147. self.set_backend_policy("vm")
  148. self._context_switches.push(True, None)
  149. elif mode == GRAPH_MODE:
  150. if self.enable_debug_runtime:
  151. self.set_backend_policy("ge")
  152. self._context_switches.push(False, None)
  153. else:
  154. raise ValueError(f'The execution mode {mode} is invalid!')
  155. self.set_param(ms_ctx_param.mode, mode)
  156. def set_backend_policy(self, policy):
  157. success = self._context_handle.set_backend_policy(policy)
  158. if not success:
  159. raise RuntimeError("Backend policy must be one of ge, vm, ms.")
  160. def set_save_graphs_path(self, save_graphs_path):
  161. self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path))
  162. def set_device_target(self, target):
  163. valid_targets = ["CPU", "GPU", "Ascend", "Davinci"]
  164. if not target in valid_targets:
  165. raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}")
  166. if target == "Davinci":
  167. target = "Ascend"
  168. self.set_param(ms_ctx_param.device_target, target)
  169. if self.enable_debug_runtime and target == "CPU":
  170. self.set_backend_policy("vm")
  171. def set_device_id(self, device_id):
  172. if device_id < 0 or device_id > 4095:
  173. raise ValueError(f"Device id must be in [0, 4095], but got {device_id}")
  174. self.set_param(ms_ctx_param.device_id, device_id)
  175. def set_max_call_depth(self, max_call_depth):
  176. if max_call_depth <= 0:
  177. raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}")
  178. self.set_param(ms_ctx_param.max_call_depth, max_call_depth)
  179. def set_profiling_options(self, option):
  180. options = ["training_trace", "task_trace",
  181. "task_trace:training_trace", "training_trace:task_trace", "op_trace"]
  182. if option not in options:
  183. raise ValueError("Profiling options must be in 'training_trace' 'task_trace' "
  184. "'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.")
  185. self.set_param(ms_ctx_param.profiling_options, option)
  186. def set_variable_memory_max_size(self, variable_memory_max_size):
  187. """set values of variable_memory_max_size and graph_memory_max_size"""
  188. if not Validator.check_str_by_regular(variable_memory_max_size, _re_pattern):
  189. raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
  190. if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
  191. raise ValueError("Context param variable_memory_max_size should be less than 31GB.")
  192. variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
  193. graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
  194. graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
  195. self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
  196. # pylint: disable=protected-access
  197. self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_)
  198. def set_max_device_memory(self, max_device_memory):
  199. if not Validator.check_str_by_regular(max_device_memory, _re_pattern):
  200. raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
  201. max_device_memory_value = float(max_device_memory[:-2])
  202. if max_device_memory_value == 0:
  203. raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
  204. self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value)
  205. def set_print_file_path(self, file_path):
  206. """Add timestamp suffix to file name. Sets print file path."""
  207. print_file_path = os.path.realpath(file_path)
  208. if os.path.isdir(print_file_path):
  209. raise IOError("Print_file_path should be file path, but got {}.".format(file_path))
  210. if os.path.exists(print_file_path):
  211. _path, _file_name = os.path.split(print_file_path)
  212. path = _make_directory(_path)
  213. file_name = _get_print_file_name(_file_name)
  214. full_file_name = os.path.join(path, file_name)
  215. else:
  216. full_file_name = print_file_path
  217. self.set_param(ms_ctx_param.print_file_path, full_file_name)
  218. setters = {
  219. 'mode': set_mode,
  220. 'backend_policy': set_backend_policy,
  221. 'save_graphs_path': set_save_graphs_path,
  222. 'device_target': set_device_target,
  223. 'device_id': set_device_id,
  224. 'max_call_depth': set_max_call_depth,
  225. 'profiling_options': set_profiling_options,
  226. 'variable_memory_max_size': set_variable_memory_max_size,
  227. 'max_device_memory': set_max_device_memory,
  228. 'print_file_path': set_print_file_path
  229. }
  230. @property
  231. def reserve_class_name_in_scope(self):
  232. """Gets whether to save the network class name in the scope."""
  233. return self._thread_local_info.reserve_class_name_in_scope
  234. @reserve_class_name_in_scope.setter
  235. def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
  236. """Sets whether to save the network class name in the scope."""
  237. self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope
  238. @property
  239. def enable_ge(self):
  240. return self._context_handle.get_backend_policy() == 'ge'
  241. @property
  242. def enable_debug_runtime(self):
  243. return self._thread_local_info.debug_runtime
  244. @enable_debug_runtime.setter
  245. def enable_debug_runtime(self, enable):
  246. thread_info = self._thread_local_info
  247. thread_info.debug_runtime = enable
  248. def _context():
  249. """
  250. Get the global _context, if context is not created, create a new one.
  251. Returns:
  252. _Context, the global context in PyNative mode.
  253. """
  254. global _k_context
  255. if _k_context is None:
  256. default_backend = 'debug'
  257. try:
  258. from mindspore import default_config
  259. default_backend = default_config.__backend__
  260. except ImportError:
  261. logger.error("import default config fail")
  262. _k_context = _Context()
  263. _k_context.enable_debug_runtime = False
  264. if default_backend == 'debug':
  265. _k_context.enable_debug_runtime = True
  266. default_backend = 'vm'
  267. _k_context.set_backend_policy(default_backend)
  268. return _k_context
  269. @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
  270. auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
  271. strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
  272. all_reduce_fusion_config=list, pipeline_stages=int)
  273. def set_auto_parallel_context(**kwargs):
  274. r"""
  275. Set auto parallel context, which is valid only for Ascend and GPU target.
  276. Auto parallel context should be configured before the initialization of your network.
  277. Note:
  278. Attribute name is required for setting attributes.
  279. If a program has tasks with different parallel modes, then before setting new parallel mode for the
  280. next task, interface mindspore.context.reset_auto_parallel_context() needs to be called to reset
  281. the configuration.
  282. Setting or changing parallel modes must be called before any creating Initializer, otherwise,
  283. RuntimeError may be raised when compiling the network.
  284. Some configurations are parallel mode specific, see the below table for details:
  285. =========================== ===========================
  286. Common AUTO_PARALLEL
  287. =========================== ===========================
  288. device_num gradient_fp32_sync
  289. global_rank loss_repeated_mean
  290. gradients_mean auto_parallel_search_mode
  291. parallel_mode strategy_ckpt_load_file
  292. all_reduce_fusion_config strategy_ckpt_save_file
  293. enable_parallel_optimizer full_batch
  294. \ pipeline_stages
  295. =========================== ===========================
  296. Args:
  297. device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
  298. global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
  299. gradients_mean (bool): Whether to perform mean operator after allreduce of gradients.
  300. "stand_alone" do not support gradients_mean. Default: False.
  301. gradient_fp32_sync (bool): Run allreduce of gradients in fp32.
  302. "stand_alone", "data_parallel" and "hybrid_parallel" do not support
  303. gradient_fp32_sync. Default: True.
  304. parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
  305. "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
  306. - stand_alone: Only one processor is working.
  307. - data_parallel: Distributes the data across different processors.
  308. - hybrid_parallel: Achieves data parallelism and model parallelism manually.
  309. - semi_auto_parallel: Achieves data parallelism and model parallelism by
  310. setting parallel strategies.
  311. - auto_parallel: Achieving parallelism automatically.
  312. auto_parallel_search_mode (str): There are two kinds of shard strategy search modes, "recursive_programming"
  313. and "dynamic_programming". Default: "dynamic_programming".
  314. - recursive_programming: Recursive programming search mode.
  315. - dynamic_programming: Dynamic programming search mode.
  316. parameter_broadcast (bool): A developing feature. Whether to broadcast parameters before training.
  317. "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
  318. broadcast. Default: False.
  319. strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
  320. strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
  321. full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
  322. should be set with True. Default: False.
  323. enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
  324. data parallel training in the benefit of time and memory saving. For now, auto parallel mode
  325. supports all optimizers. Data parallel mode only supports `Lamb` and `AdamWeightDecay`.
  326. Default: False.
  327. all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
  328. and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed.
  329. pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
  330. the devices are distributed alone the pipeline. The total devices will be divided into
  331. 'pipeline_stags' stages. This currently could only be used when
  332. parall mode semi_auto_parallel is enabled.
  333. Raises:
  334. ValueError: If input key is not attribute in auto parallel context.
  335. Examples:
  336. >>> context.set_auto_parallel_context(device_num=8)
  337. >>> context.set_auto_parallel_context(global_rank=0)
  338. >>> context.set_auto_parallel_context(gradients_mean=True)
  339. >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
  340. >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
  341. >>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
  342. >>> context.set_auto_parallel_context(parameter_broadcast=False)
  343. >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
  344. >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
  345. >>> context.set_auto_parallel_context(full_batch=True)
  346. >>> context.set_auto_parallel_context(enable_parallel_optimizer=False)
  347. >>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
  348. >>> context.set_auto_parallel_context(pipeline_stages=2)
  349. """
  350. _set_auto_parallel_context(**kwargs)
  351. def get_auto_parallel_context(attr_key):
  352. """
  353. Gets auto parallel context attribute value according to the key.
  354. Args:
  355. attr_key (str): The key of the attribute.
  356. Returns:
  357. Returns attribute value according to the key.
  358. Raises:
  359. ValueError: If input key is not attribute in auto parallel context.
  360. """
  361. return _get_auto_parallel_context(attr_key)
  362. def reset_auto_parallel_context():
  363. """
  364. Reset auto parallel context attributes to the default values:
  365. - device_num: 1.
  366. - global_rank: 0.
  367. - gradients_mean: False.
  368. - gradient_fp32_sync: True.
  369. - parallel_mode: 'stand_alone'.
  370. - auto_parallel_search_mode: 'dynamic_programming'.
  371. - parameter_broadcast: False.
  372. - strategy_ckpt_load_file: ''.
  373. - strategy_ckpt_save_file: ''.
  374. - full_batch: False.
  375. - enable_parallel_optimizer: False.
  376. """
  377. _reset_auto_parallel_context()
  378. def _check_target_specific_cfgs(device, arg_key):
  379. """Checking whether a config is sutable for a specified device"""
  380. device_cfgs = {
  381. 'enable_auto_mixed_precision': ['Ascend'],
  382. 'enable_dump': ['Ascend'],
  383. 'save_dump_path': ['Ascend'],
  384. 'enable_graph_kernel': ['Ascend', 'GPU'],
  385. 'enable_reduce_precision': ['Ascend'],
  386. 'enable_profiling': ['Ascend'],
  387. 'profiling_options': ['Ascend'],
  388. 'print_file_path': ['Ascend'],
  389. 'variable_memory_max_size': ['Ascend'],
  390. 'max_device_memory': ['GPU']
  391. }
  392. # configs not in map device_cfgs are supposed to be suitable for all devices
  393. if not arg_key in device_cfgs:
  394. return True
  395. supported_devices = device_cfgs[arg_key]
  396. if device in supported_devices:
  397. return True
  398. logger.warning(f"Config '{arg_key}' only supports devices in {supported_devices}, current device is '{device}'"
  399. ", ignore it.")
  400. return False
  401. @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
  402. save_graphs_path=str, enable_dump=bool,
  403. save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
  404. enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
  405. enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
  406. enable_sparse=bool, max_call_depth=int)
  407. def set_context(**kwargs):
  408. """
  409. Sets context for running environment.
  410. Context should be configured before running your program. If there is no configuration,
  411. the "Ascend" device target will be used by default. GRAPH_MODE or
  412. PYNATIVE_MODE can be set by `mode` attribute and both modes support all backends, default
  413. mode is PYNATIVE_MODE.
  414. When the `save_graphs` attribute is set to True, attribute of `save_graphs_path` is used to set the
  415. intermediate compilation graph storage path. By default, the graphs are saved in the current directory.
  416. For other configurations and arguments, please refer to the corresponding module
  417. description, the configuration is optional and can be enabled when needed.
  418. Note:
  419. Attribute name is required for setting attributes.
  420. The mode is not recommended to be changed after net was initilized because the implementations of some
  421. operations are different in graph mode and pynative mode. Default: PYNATIVE_MODE.
  422. Some configurations are device specific, see the bellow table for details:
  423. =========================== =========================== =================
  424. Common(CPU/GPU/Ascend) Ascend GPU
  425. =========================== =========================== =================
  426. check_bprop enable_auto_mixed_precision max_device_memory
  427. device_id enable_dump enable_graph_kernel
  428. device_target save_dump_path
  429. enable_sparse enable_graph_kernel
  430. max_call_depth enable_reduce_precision
  431. mode enable_profiling
  432. reserve_class_name_in_scope profiling_options
  433. save_graphs variable_memory_max_size
  434. save_graphs_path print_file_path
  435. =========================== =========================== =================
  436. Args:
  437. mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). Default: PYNATIVE_MODE(1).
  438. device_target (str): The target device to run, support "Ascend", "GPU", and "CPU". Default: "Ascend".
  439. device_id (int): ID of the target device, the value must be in [0, device_num_per_host-1],
  440. while device_num_per_host should be no more than 4096. Default: 0.
  441. save_graphs (bool): Whether to save graphs. Default: False.
  442. save_graphs_path (str): Path to save graphs. Default: "."
  443. enable_auto_mixed_precision (bool): Whether to enable auto mixed precision. Default: False.
  444. enable_graph_kernel (bool): Whether to enable composition of basic primitives. These primitives would be
  445. compiled into a fused kernel automatically. Default: False.
  446. reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True.
  447. enable_reduce_precision (bool): Whether to enable precision reduction. Default: True.
  448. enable_dump (bool): Whether to enable dump. Default: False.
  449. save_dump_path (str): When the program is executed on Ascend, operators can dump data in this path.
  450. The root dump path is configured in /home/HwHiAiUser/ide_daemon/ide_daemon.cfg.
  451. So the real dump path is "{configured root dump path}/{`save_dump_path`}". Default: ".".
  452. variable_memory_max_size (str): Set the maximum size of the variable memory max size. Default: "0GB".
  453. enable_profiling (bool): Whether to open profiling. Default: False.
  454. profiling_options (str): Set profiling collection options, operators can profiling data here.
  455. The values of profiling collection options are as follows, supporting the collection of multiple data.
  456. - training_trace: collect iterative trajectory data, that is, the training task and software information of
  457. the AI software stack, to achieve performance analysis of the training task, focusing on data
  458. enhancement, forward and backward calculation, gradient aggregation update and other related data.
  459. - task_trace: collect task trajectory data, that is, the hardware information of the HWTS/AICore of
  460. the Ascend 910 processor, and analyze the information of beginning and ending of the task.
  461. - op_trace: collect single operator performance data.
  462. The profiling can choose the combination of `training_trace`, `task_trace`,
  463. `training_trace` and `task_trace` combination, and eparated by colons;
  464. a single operator can choose `op_trace`, `op_trace` cannot be combined with
  465. `training_trace` and `task_trace`. Default: "training_trace".
  466. check_bprop (bool): Whether to check bprop. Default: False.
  467. max_device_memory (str): Sets the maximum memory available for devices.
  468. Currently, it is only supported on GPU. The format is "xxGB". Default: "1024GB".
  469. print_file_path (str): The path of saving print data. If this parameter is set, print data is saved to
  470. a file by default, and turns off printing to the screen. If the file already exists, add a timestamp
  471. suffix to the file. Default: ''.
  472. enable_sparse (bool): Whether to enable sparsity feature. Default: False.
  473. max_call_depth(int): Specify the maximum depth of function call. Default: 1000.
  474. Raises:
  475. ValueError: If input key is not an attribute in context.
  476. Examples:
  477. >>> context.set_context(mode=context.GRAPH_MODE)
  478. >>> context.set_context(mode=context.PYNATIVE_MODE)
  479. >>> context.set_context(device_target="Ascend")
  480. >>> context.set_context(device_id=0)
  481. >>> context.set_context(save_graphs=True, save_graphs_path="./model.ms")
  482. >>> context.set_context(enable_reduce_precision=True)
  483. >>> context.set_context(enable_dump=True, save_dump_path=".")
  484. >>> context.set_context(reserve_class_name_in_scope=True)
  485. >>> context.set_context(variable_memory_max_size="6GB")
  486. >>> context.set_context(mode=context.GRAPH_MODE,
  487. >>> device_target="Ascend",device_id=0, save_graphs=True,
  488. >>> save_graphs_path="/mindspore")
  489. >>> context.set_context(enable_profiling=True, profiling_options="training_trace")
  490. >>> context.set_context(max_device_memory="3.5GB")
  491. >>> context.set_context(print_file_path="print.pb")
  492. >>> context.set_context(max_call_depth=80)
  493. """
  494. ctx = _context()
  495. # set device target first
  496. if 'device_target' in kwargs:
  497. ctx.set_device_target(kwargs['device_target'])
  498. device = ctx.get_param(ms_ctx_param.device_target)
  499. if not device.lower() in __device_target__:
  500. raise ValueError(f"Error, package type {__package_name__} support device type {__device_target__}, "
  501. f"but got device target {device}")
  502. device = ctx.get_param(ms_ctx_param.device_target)
  503. for key, value in kwargs.items():
  504. if not _check_target_specific_cfgs(device, key):
  505. continue
  506. if hasattr(ctx, key):
  507. setattr(ctx, key, value)
  508. continue
  509. if key in ctx.setters:
  510. ctx.setters[key](ctx, value)
  511. continue
  512. # enum variables begining with '_' are for internal use
  513. if key in ms_ctx_param.__members__ and key[0] != '_':
  514. ctx.set_param(ms_ctx_param.__members__[key], value)
  515. continue
  516. raise ValueError("Set context keyword %s is not recognized!" % key)
  517. def get_context(attr_key):
  518. """
  519. Gets context attribute value according to the input key.
  520. Args:
  521. attr_key (str): The key of the attribute.
  522. Returns:
  523. Object, The value of given attribute key.
  524. Raises:
  525. ValueError: If input key is not an attribute in context.
  526. """
  527. ctx = _context()
  528. device = ctx.get_param(ms_ctx_param.device_target)
  529. _ = _check_target_specific_cfgs(device, attr_key)
  530. if hasattr(ctx, attr_key):
  531. return getattr(ctx, attr_key)
  532. # enum variables begining with '_' are for internal use
  533. if attr_key in ms_ctx_param.__members__ and attr_key[0] != '_':
  534. return ctx.get_param(ms_ctx_param.__members__[attr_key])
  535. raise ValueError("Get context keyword %s is not recognized!" % attr_key)
  536. class ParallelMode:
  537. """
  538. Parallel mode options.
  539. There are five kinds of parallel modes, "STAND_ALONE", "DATA_PARALLEL",
  540. "HYBRID_PARALLEL", "SEMI_AUTO_PARALLEL" and "AUTO_PARALLEL". Default: "STAND_ALONE".
  541. - STAND_ALONE: Only one processor is working.
  542. - DATA_PARALLEL: Distributes the data across different processors.
  543. - HYBRID_PARALLEL: Achieves data parallelism and model parallelism manually.
  544. - SEMI_AUTO_PARALLEL: Achieves data parallelism and model parallelism by setting parallel strategies.
  545. - AUTO_PARALLEL: Achieves parallelism automatically.
  546. MODE_LIST: The list of all supported parallel modes.
  547. """
  548. STAND_ALONE = "stand_alone"
  549. DATA_PARALLEL = "data_parallel"
  550. HYBRID_PARALLEL = "hybrid_parallel"
  551. SEMI_AUTO_PARALLEL = "semi_auto_parallel"
  552. AUTO_PARALLEL = "auto_parallel"
  553. MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]
  554. @args_type_check(enable_ps=bool)
  555. def set_ps_context(**kwargs):
  556. """
  557. Set parameter server training mode context.
  558. Note:
  559. Some other environment variables should also be set for parameter server training mode.
  560. These environment variables are listed below:
  561. .. code-block::
  562. MS_SERVER_NUM # Server number
  563. MS_WORKER_NUM # Worker number
  564. MS_SCHED_HOST # Scheduler IP address
  565. MS_SCHED_PORT # Scheduler port
  566. MS_ROLE # The role of this process:
  567. # MS_SCHED represents the scheduler,
  568. # MS_WORKER represents the worker,
  569. # MS_PSERVER represents the Server
  570. Args:
  571. enable_ps (bool): Whether to enable parameter server training mode.
  572. Only after enable_ps is set True, the environment variables will be effective.
  573. Default: False.
  574. Raises:
  575. ValueError: If input key is not the attribute in parameter server training mode context.
  576. Examples:
  577. >>> context.set_ps_context(enable_ps=True)
  578. """
  579. _set_ps_context(**kwargs)
  580. def get_ps_context(attr_key):
  581. """
  582. Get parameter server training mode context attribute value according to the key.
  583. Args:
  584. attr_key (str): The key of the attribute.
  585. Returns:
  586. Returns attribute value according to the key.
  587. Raises:
  588. ValueError: If input key is not attribute in auto parallel context.
  589. """
  590. return _get_ps_context(attr_key)
  591. def reset_ps_context():
  592. """
  593. Reset parameter server training mode context attributes to the default values:
  594. - enable_ps: False.
  595. """
  596. _reset_ps_context()