
_auto_parallel_context.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Context of auto parallel"""
import threading

import mindspore.context as context
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
from mindspore.parallel._ps_context import _is_role_pserver
from mindspore._c_expression import AutoParallelContext
from mindspore._checkparam import args_type_check

_MAX_GROUP_NAME_LEN = 127
_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"

class _AutoParallelContext:
    """
    _AutoParallelContext is the environment in which operations are executed.

    Note:
        Creating a context by instantiating this class directly is not recommended.
        Use auto_parallel_context() to get the context instead, since the context is a singleton.
    """
    _instance = None
    _instance_lock = threading.Lock()

    def __init__(self):
        self._context_handle = AutoParallelContext.get_instance()

    def __new__(cls):
        if cls._instance is None:
            cls._instance_lock.acquire()
            cls._instance = object.__new__(cls)
            cls._instance_lock.release()
        return cls._instance
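
    # The singleton instance is created under _instance_lock; callers should obtain it
    # through auto_parallel_context() below rather than constructing this class directly.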

    def check_context_handle(self):
        """
        Check context handle.

        Raises:
            ValueError: If the context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")

    def set_device_num(self, device_num):
        """
        Set device num for auto parallel.

        Args:
            device_num (int): The device number.

        Raises:
            ValueError: If the device num is not in [1, 4096].
        """
        self.check_context_handle()
        if device_num < 1 or device_num > 4096:
            raise ValueError("Device num must be in [1, 4096], but got {}".format(device_num))
        self._context_handle.set_device_num(device_num)

    def get_device_num(self):
        """Get device num."""
        self.check_context_handle()
        return self._context_handle.get_device_num()

    def set_global_rank(self, global_rank):
        """
        Set global rank for auto parallel.

        Args:
            global_rank (int): The rank id of current rank.

        Raises:
            ValueError: If the global rank is not in [0, 4095].
        """
        self.check_context_handle()
        if global_rank < 0 or global_rank > 4095:
            raise ValueError("Global rank must be in [0, 4095], but got {}".format(global_rank))
        self._context_handle.set_global_rank(global_rank)

    def get_global_rank(self):
        """Get current rank id."""
        self.check_context_handle()
        return self._context_handle.get_global_rank()

    def set_pipeline_stages(self, stages):
        """Set the stages of the pipeline."""
        self.check_context_handle()
        self._context_handle.set_pipeline_stage_split_num(stages)

    def get_pipeline_stages(self):
        """Get the stages of the pipeline."""
        self.check_context_handle()
        return self._context_handle.get_pipeline_stage_split_num()

    def set_gradients_mean(self, gradients_mean):
        """
        Set gradients_mean flag.

        Note:
            If gradients_mean is true, it will insert a div operator after parameter gradients allreduce.

        Args:
            gradients_mean (bool): The gradients_mean flag.
        """
        self.check_context_handle()
        self._context_handle.set_gradients_mean(gradients_mean)

    def get_gradients_mean(self):
        """Get gradients_mean flag."""
        self.check_context_handle()
        return self._context_handle.get_gradients_mean()

    def set_gradient_fp32_sync(self, gradient_fp32_sync):
        """
        Set gradient_fp32_sync.

        Note:
            If gradient_fp32_sync is true,
            it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.

        Args:
            gradient_fp32_sync (bool): The gradient_fp32_sync flag.
        """
        self.check_context_handle()
        self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync)

    def get_gradient_fp32_sync(self):
        """Get gradient_fp32_sync flag."""
        self.check_context_handle()
        return self._context_handle.get_gradient_fp32_sync()

    def set_loss_repeated_mean(self, loss_repeated_mean):
        """
        Set loss_repeated_mean flag.

        Note:
            If loss_repeated_mean is true,
            distributed automatic differentiation will perform a mean operator
            in backward in the case of repeated calculations.

        Args:
            loss_repeated_mean (bool): The loss_repeated_mean flag.
        """
        self.check_context_handle()
        self._context_handle.set_loss_repeated_mean(loss_repeated_mean)

    def get_loss_repeated_mean(self):
        """Get loss_repeated_mean flag."""
        self.check_context_handle()
        return self._context_handle.get_loss_repeated_mean()

    def set_parallel_mode(self, parallel_mode):
        """
        Set parallel mode for auto parallel.

        Args:
            parallel_mode (str): The parallel mode of auto parallel.

        Raises:
            ValueError: If parallel mode is not supported.
        """
        self.check_context_handle()
        ret = self._context_handle.set_parallel_mode(parallel_mode)
        if ret is False:
            raise ValueError("Parallel mode does not support {}".format(parallel_mode))

    def get_parallel_mode(self):
        """Get parallel mode."""
        self.check_context_handle()
        if _is_role_pserver():
            return context.ParallelMode.STAND_ALONE
        return self._context_handle.get_parallel_mode()

    def set_strategy_search_mode(self, auto_parallel_search_mode):
        """
        Set search mode of strategy.

        Args:
            auto_parallel_search_mode (str): The search mode of strategy.
        """
        self.check_context_handle()
        ret = self._context_handle.set_strategy_search_mode(auto_parallel_search_mode)
        if ret is False:
            raise ValueError("Strategy search mode does not support {}".format(auto_parallel_search_mode))

    def get_strategy_search_mode(self):
        """Get search mode of strategy."""
        self.check_context_handle()
        return self._context_handle.get_strategy_search_mode()

    def set_parameter_broadcast(self, parameter_broadcast):
        """
        Set parameter broadcast.

        Args:
            parameter_broadcast (bool): Parameter broadcast or not.
        """
        self.check_context_handle()
        self._context_handle.set_parameter_broadcast(parameter_broadcast)

    def get_parameter_broadcast(self):
        """Get parameter broadcast flag."""
        self.check_context_handle()
        return self._context_handle.get_parameter_broadcast()

    def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file):
        """
        Set strategy checkpoint load path.

        Args:
            strategy_ckpt_load_file (str): Path to load parallel strategy checkpoint.
        """
        self.check_context_handle()
        self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file)

    def get_strategy_ckpt_load_file(self):
        """Get strategy checkpoint load path."""
        self.check_context_handle()
        return self._context_handle.get_strategy_ckpt_load_file()

    def set_full_batch(self, full_batch):
        """
        Set whether to load the full batch on each device.

        Args:
            full_batch (bool): True if the full batch is loaded on each device.
        """
        self.check_context_handle()
        self._context_handle.set_full_batch(full_batch)

    def get_full_batch(self):
        """Get whether the full batch is loaded on each device."""
        self.check_context_handle()
        if _is_role_pserver():
            return False
        return self._context_handle.get_full_batch()

    def set_grad_accumulation_step(self, grad_accumulation_step):
        """
        Set grad accumulation step.

        Args:
            grad_accumulation_step (int): The grad accumulation step.
        """
        self.check_context_handle()
        self._context_handle.set_grad_accumulation_step(grad_accumulation_step)

    def get_grad_accumulation_step(self):
        """Get grad accumulation step."""
        self.check_context_handle()
        return self._context_handle.get_grad_accumulation_step()

    def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
        """
        Set strategy checkpoint save path.

        Args:
            strategy_ckpt_save_file (str): Path to save parallel strategy checkpoint.
        """
        self.check_context_handle()
        import os
        dir_path = os.path.dirname(strategy_ckpt_save_file)
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path)
        self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)

    def get_strategy_ckpt_save_file(self):
        """Get strategy checkpoint save path."""
        self.check_context_handle()
        return self._context_handle.get_strategy_ckpt_save_file()

    def set_group_ckpt_save_file(self, group_ckpt_save_file):
        """Set group checkpoint save path."""
        self.check_context_handle()
        import os
        dir_path = os.path.dirname(group_ckpt_save_file)
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path)
        self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)

    def get_parameter_broadcast_is_set(self):
        """Get whether parameter broadcast is set."""
        self.check_context_handle()
        return self._context_handle.get_parameter_broadcast_is_set()

    def set_all_reduce_fusion_split_indices(self, indices, group=""):
        """
        Set allreduce fusion strategy by parameters indices.

        Args:
            indices (list): Indices list.
            group (str): The communication group of hccl/nccl.

        Raises:
            TypeError: If type of indices item is not int.
            TypeError: If group is not a python str.
        """
        self.check_context_handle()
        if not indices:
            raise ValueError('indices can not be empty')
        if isinstance(indices, (list)):
            for index in indices:
                if not isinstance(index, int):
                    raise TypeError('indices has invalid value')
        else:
            raise TypeError('indices must be a python list')
        if len(set(indices)) != len(indices):
            raise ValueError('indices has duplicate elements')
        if sorted(indices) != indices:
            raise ValueError('elements in indices must be sorted in ascending order')
        if isinstance(group, (str)):
            group_len = len(group)
            if group_len > _MAX_GROUP_NAME_LEN:
                raise ValueError(f'Group name len is out of range {_MAX_GROUP_NAME_LEN}')
        else:
            raise TypeError('Group must be a python str')
        if group == "":
            if context.get_context("device_target") == "Ascend":
                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
            else:
                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
        self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
        if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
            _set_fusion_strategy_by_idx(indices)
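
    # Illustration (assumed semantics): with indices such as [20, 35], gradients of
    # parameters whose index is below 20 go into the first fused allreduce, those in
    # [20, 35) into the second, and the remaining parameters into a final group.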

    def get_all_reduce_fusion_split_indices(self, group=""):
        """
        Get allreduce fusion split indices.

        Args:
            group (str): The communication group of hccl/nccl.

        Returns:
            Return split indices list according to the group.

        Raises:
            TypeError: If group is not a python str.
        """
        self.check_context_handle()
        if isinstance(group, (str)):
            group_len = len(group)
            if group_len > _MAX_GROUP_NAME_LEN:
                raise ValueError(f'Group name len is out of range {_MAX_GROUP_NAME_LEN}')
        else:
            raise TypeError('Group must be a python str')
        if group == "":
            if context.get_context("device_target") == "Ascend":
                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
            else:
                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
        return self._context_handle.get_all_reduce_fusion_split_indices(group)

    def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
        """
        Set allreduce fusion strategy by parameters data sizes.

        Args:
            sizes (list): Sizes list.
            group (str): The communication group of hccl/nccl.

        Raises:
            TypeError: If type of sizes item is not int.
            TypeError: If group is not a python str.
        """
        self.check_context_handle()
        if isinstance(sizes, (list)):
            for size in sizes:
                if not isinstance(size, int):
                    raise TypeError('sizes has invalid value')
        else:
            raise TypeError('sizes must be a python list')
        if isinstance(group, (str)):
            group_len = len(group)
            if group_len > _MAX_GROUP_NAME_LEN:
                raise ValueError(f'Group name len is out of range {_MAX_GROUP_NAME_LEN}')
        else:
            raise TypeError('Group must be a python str')
        if group == "":
            if context.get_context("device_target") == "Ascend":
                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
            else:
                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
        self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group)
        if context.get_context("device_target") == "Ascend":
            _set_fusion_strategy_by_size(sizes)

    def get_all_reduce_fusion_split_sizes(self, group=""):
        """
        Get allreduce fusion split sizes.

        Args:
            group (str): The communication group of hccl/nccl.

        Returns:
            Return split sizes list according to the group.

        Raises:
            TypeError: If group is not a python str.
        """
        self.check_context_handle()
        if isinstance(group, (str)):
            group_len = len(group)
            if group_len > _MAX_GROUP_NAME_LEN:
                raise ValueError(f'Group name len is out of range {_MAX_GROUP_NAME_LEN}')
        else:
            raise TypeError('Group must be a python str')
        if group == "":
            if context.get_context("device_target") == "Ascend":
                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
            else:
                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
        return self._context_handle.get_all_reduce_fusion_split_sizes(group)

    def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
        """
        Set enable/disable all reduce fusion.

        Args:
            enable_all_reduce_fusion (bool): Enable/disable all reduce fusion.
        """
        self.check_context_handle()
        if not isinstance(enable_all_reduce_fusion, bool):
            raise TypeError('enable_all_reduce_fusion is invalid type')
        self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)

    def get_enable_all_reduce_fusion(self):
        """Get all reduce fusion flag."""
        self.check_context_handle()
        return self._context_handle.get_enable_all_reduce_fusion()

    def get_device_num_is_set(self):
        """Get whether the device number is set."""
        self.check_context_handle()
        return self._context_handle.get_device_num_is_set()

    def get_global_rank_is_set(self):
        """Get whether the global rank is set."""
        self.check_context_handle()
        return self._context_handle.get_global_rank_is_set()

    def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
        """
        Set enable/disable parallel optimizer.

        Args:
            enable_parallel_optimizer (bool): Enable/disable parallel optimizer.
        """
        self.check_context_handle()
        if not isinstance(enable_parallel_optimizer, bool):
            raise TypeError('enable_parallel_optimizer is invalid type')
        self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)

    def get_enable_parallel_optimizer(self):
        """Get parallel optimizer flag."""
        self.check_context_handle()
        return self._context_handle.get_enable_parallel_optimizer()

    def reset(self):
        """Reset all settings."""
        self.check_context_handle()
        self._context_handle.reset()

_auto_parallel_context = None


def auto_parallel_context():
    """
    Get the global _auto_parallel_context. If it is not created, create a new one.

    Returns:
        _AutoParallelContext, the global auto parallel context.
    """
    global _auto_parallel_context
    if _auto_parallel_context is None:
        _auto_parallel_context = _AutoParallelContext()
    return _auto_parallel_context


_set_auto_parallel_context_func_map = {
    "device_num": auto_parallel_context().set_device_num,
    "global_rank": auto_parallel_context().set_global_rank,
    "gradients_mean": auto_parallel_context().set_gradients_mean,
    "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
    "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
    "pipeline_stages": auto_parallel_context().set_pipeline_stages,
    "parallel_mode": auto_parallel_context().set_parallel_mode,
    "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
    "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
    "group_ckpt_save_file": auto_parallel_context().set_group_ckpt_save_file,
    "full_batch": auto_parallel_context().set_full_batch,
    "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
    "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
    "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices}

_get_auto_parallel_context_func_map = {
    "device_num": auto_parallel_context().get_device_num,
    "global_rank": auto_parallel_context().get_global_rank,
    "gradients_mean": auto_parallel_context().get_gradients_mean,
    "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
    "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
    "pipeline_stages": auto_parallel_context().get_pipeline_stages,
    "parallel_mode": auto_parallel_context().get_parallel_mode,
    "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
    "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
    "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
    "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
    "full_batch": auto_parallel_context().get_full_batch,
    "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
    "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
    "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices}

@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
                 loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                 parameter_broadcast=bool, strategy_ckpt_load_file=str,
                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str)
def _set_auto_parallel_context(**kwargs):
    """
    Set auto parallel context.

    Note:
        Attribute name is required for setting attributes.

    Args:
        device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
        global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
        gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
        loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
            calculations. Default: True.
        gradient_fp32_sync (bool): Gradients are all-reduced in fp32 even if the gradients are fp16
            when this flag is True. Default: True.
        parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
            "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".

            - stand_alone: Only one processor working.
            - data_parallel: Distributing the data across different processors.
            - hybrid_parallel: Achieving data parallelism and model parallelism manually.
            - semi_auto_parallel: Achieving data parallelism and model parallelism by
              setting parallel strategies.
            - auto_parallel: Achieving parallelism automatically.
        auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
            and "dynamic_programming". Default: "dynamic_programming".

            - recursive_programming: Recursive programming search mode.
            - dynamic_programming: Dynamic programming search mode.
        parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
            "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
            broadcast. Default: False.
        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
        group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: ''
        full_batch (bool): Whether to load the whole batch on each device. Default: False.
        enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.
        pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
            the devices are distributed along the pipeline. The total devices will be divided into
            'pipeline_stages' stages. This currently can only be used when
            parallel mode semi_auto_parallel is enabled. Default: 0.

    Raises:
        ValueError: If input key is not attribute in auto parallel context.
    """
    for key, value in kwargs.items():
        if key not in _set_auto_parallel_context_func_map:
            raise ValueError("Set context keyword %s is not recognized!" % key)
        set_func = _set_auto_parallel_context_func_map[key]
        set_func(value)

def _get_auto_parallel_context(attr_key):
    """
    Get auto parallel context attribute value according to the key.

    Args:
        attr_key (str): The key of the attribute.

    Returns:
        Return attribute value according to the key.

    Raises:
        ValueError: If input key is not attribute in auto parallel context.
    """
    if attr_key not in _get_auto_parallel_context_func_map:
        raise ValueError("Get context keyword %s is not recognized!" % attr_key)
    get_func = _get_auto_parallel_context_func_map[attr_key]
    return get_func()

def _reset_auto_parallel_context():
    """
    Reset auto parallel context attributes to the default values:

    - device_num: 1.
    - global_rank: 0.
    - gradients_mean: False.
    - gradient_fp32_sync: True.
    - parallel_mode: "stand_alone".
    - parameter_broadcast: False.
    - strategy_ckpt_load_file: ""
    - strategy_ckpt_save_file: ""
    - enable_parallel_optimizer: False
    - auto_parallel_search_mode: dynamic_programming
    - pipeline_stages: 0
    """
    auto_parallel_context().reset()
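
A minimal usage sketch of the helpers above, with hypothetical values. It assumes a working MindSpore installation (the C++ AutoParallelContext backend must be available); in normal use these private helpers are driven through the public mindspore.context.set_auto_parallel_context API rather than called directly.

    from mindspore.parallel._auto_parallel_context import (
        _get_auto_parallel_context, _reset_auto_parallel_context, _set_auto_parallel_context)

    # Configure an 8-device semi-auto-parallel run and save the parallel strategy to disk.
    _set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel",
                               gradients_mean=True, strategy_ckpt_save_file="./strategy.ckpt")

    print(_get_auto_parallel_context("device_num"))      # expected: 8
    print(_get_auto_parallel_context("parallel_mode"))   # expected: semi_auto_parallel

    # Restore every attribute to its default value (see _reset_auto_parallel_context above).
    _reset_auto_parallel_context()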