You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

_auto_parallel_context.py 30 kB

5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Context of auto parallel"""
  16. import threading
  17. import mindspore.context as context
  18. import mindspore.log as logger
  19. from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
  20. from mindspore.parallel._ps_context import _is_role_pserver
  21. from mindspore._c_expression import AutoParallelContext
  22. from mindspore._checkparam import args_type_check, Validator
  23. _MAX_GROUP_NAME_LEN = 127
  24. _DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
  25. _DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"
  26. class _AutoParallelContext:
  27. """
  28. _AutoParallelContext is the environment in which operations are executed
  29. Note:
  30. Create a context through instantiating Context object is not recommended.
  31. Should use auto_parallel_context() to get the context since Context is singleton.
  32. """
  33. _instance = None
  34. _instance_lock = threading.Lock()
  35. def __init__(self):
  36. self._context_handle = AutoParallelContext.get_instance()
  37. def __new__(cls):
  38. if cls._instance is None:
  39. cls._instance_lock.acquire()
  40. cls._instance = object.__new__(cls)
  41. cls._instance_lock.release()
  42. return cls._instance
  43. def check_context_handle(self):
  44. """
  45. Check context handle.
  46. Raises:
  47. ValueError: If the context handle is none.
  48. """
  49. if self._context_handle is None:
  50. raise ValueError("Context handle is none in context!!!")
  51. def set_device_num(self, device_num):
  52. """
  53. Set device num for auto parallel.
  54. Args:
  55. device_num (int): The device number.
  56. Raises:
  57. ValueError: If the device num is not in [1, 4096].
  58. """
  59. self.check_context_handle()
  60. if device_num < 1 or device_num > 4096:
  61. raise ValueError("Device num must be in [1, 4096], but got {}".format(device_num))
  62. self._context_handle.set_device_num(device_num)
  63. def get_device_num(self):
  64. """Get device num."""
  65. self.check_context_handle()
  66. return self._context_handle.get_device_num()
  67. def set_global_rank(self, global_rank):
  68. """
  69. Set global rank for auto parallel.
  70. Args:
  71. global_rank (int): The rank id of current rank.
  72. Raises:
  73. ValueError: If the global rank is not in [1, 4096].
  74. """
  75. self.check_context_handle()
  76. if global_rank < 0 or global_rank > 4095:
  77. raise ValueError("Global rank must be in [0, 4095], but got {}".format(global_rank))
  78. self._context_handle.set_global_rank(global_rank)
  79. def get_global_rank(self):
  80. """Get current rank id."""
  81. self.check_context_handle()
  82. return self._context_handle.get_global_rank()
  83. def set_pipeline_stages(self, stages):
  84. """Set the stages of the pipeline"""
  85. self.check_context_handle()
  86. self._context_handle.set_pipeline_stage_split_num(stages)
  87. def get_pipeline_stages(self):
  88. """Get the stages of the pipeline"""
  89. self.check_context_handle()
  90. return self._context_handle.get_pipeline_stage_split_num()
  91. def set_gradients_mean(self, gradients_mean):
  92. """
  93. Set gradients_mean flag.
  94. Note:
  95. If gradients_mean is true, it will insert a div operator after parameter gradients allreduce.
  96. Args:
  97. gradients_mean (bool): The gradients_mean flag.
  98. """
  99. self.check_context_handle()
  100. self._context_handle.set_gradients_mean(gradients_mean)
  101. def get_gradients_mean(self):
  102. """Get gradients_mean flag."""
  103. self.check_context_handle()
  104. return self._context_handle.get_gradients_mean()
  105. def set_gradient_fp32_sync(self, gradient_fp32_sync):
  106. """
  107. Set gradient_fp32_sync.
  108. Note:
  109. If gradient_fp32_sync is true,
  110. it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.
  111. Args:
  112. gradient_fp32_sync (bool): The gradient_fp32_sync flag.
  113. """
  114. self.check_context_handle()
  115. self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync)
  116. def get_gradient_fp32_sync(self):
  117. """Get gradient_fp32_sync flag."""
  118. self.check_context_handle()
  119. return self._context_handle.get_gradient_fp32_sync()
  120. def set_loss_repeated_mean(self, loss_repeated_mean):
  121. """
  122. Set loss_repeated_mean flag.
  123. Note:
  124. If loss_repeated_mean is true,
  125. Distributed automatic differentiation will perform a mean operator
  126. in backward in the case of repeated calculations.
  127. Args:
  128. loss_repeated_mean (bool): The loss_repeated_mean flag.
  129. """
  130. self.check_context_handle()
  131. self._context_handle.set_loss_repeated_mean(loss_repeated_mean)
  132. def get_loss_repeated_mean(self):
  133. """Get loss_repeated_mean flag."""
  134. self.check_context_handle()
  135. return self._context_handle.get_loss_repeated_mean()
  136. def set_parallel_mode(self, parallel_mode):
  137. """
  138. Set parallel mode for auto parallel.
  139. Args:
  140. parallel_mode (str): The parallel mode of auto parallel.
  141. Raises:
  142. ValueError: If parallel mode is not supported.
  143. """
  144. self.check_context_handle()
  145. ret = self._context_handle.set_parallel_mode(parallel_mode)
  146. if ret is False:
  147. raise ValueError("Parallel mode does not support {}".format(parallel_mode))
  148. def get_parallel_mode(self):
  149. """Get parallel mode."""
  150. self.check_context_handle()
  151. if _is_role_pserver():
  152. return context.ParallelMode.STAND_ALONE
  153. return self._context_handle.get_parallel_mode()
  154. def set_strategy_search_mode(self, auto_parallel_search_mode):
  155. """
  156. Set search mode of strategy.
  157. Args:
  158. auto_parallel_search_mode (str): The search mode of strategy.
  159. """
  160. self.check_context_handle()
  161. ret = self._context_handle.set_strategy_search_mode(auto_parallel_search_mode)
  162. if ret is False:
  163. raise ValueError("Strategy search mode does not support {}".format(auto_parallel_search_mode))
  164. def get_strategy_search_mode(self):
  165. """Get search mode of strategy."""
  166. self.check_context_handle()
  167. return self._context_handle.get_strategy_search_mode()
  168. def set_parameter_broadcast(self, parameter_broadcast):
  169. """
  170. Set parameter broadcast.
  171. Args:
  172. parameter_broadcast (bool): Parameter broadcast or not.
  173. """
  174. self.check_context_handle()
  175. self._context_handle.set_parameter_broadcast(parameter_broadcast)
  176. def get_parameter_broadcast(self):
  177. """Get parameter broadcast flag."""
  178. self.check_context_handle()
  179. return self._context_handle.get_parameter_broadcast()
  180. def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file):
  181. """
  182. Set strategy checkpoint load path.
  183. Args:
  184. strategy_ckpt_load_file (bool): Path to load parallel strategy checkpoint.
  185. """
  186. self.check_context_handle()
  187. self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file)
  188. def get_strategy_ckpt_load_file(self):
  189. """Get strategy checkpoint load path."""
  190. self.check_context_handle()
  191. return self._context_handle.get_strategy_ckpt_load_file()
  192. def set_full_batch(self, full_batch):
  193. """
  194. Set whether load full batch on each device.
  195. Args:
  196. full_batch (bool): True if load full batch on each device.
  197. """
  198. self.check_context_handle()
  199. self._context_handle.set_full_batch(full_batch)
  200. def get_full_batch(self):
  201. """Get whether load full batch on each device."""
  202. self.check_context_handle()
  203. if _is_role_pserver():
  204. return False
  205. return self._context_handle.get_full_batch()
  206. def set_grad_accumulation_step(self, grad_accumulation_step):
  207. """
  208. Set grad accumulation step.
  209. Args:
  210. grad_accumulation_step (int): The grad accumulation step.
  211. """
  212. self.check_context_handle()
  213. Validator.check_positive_int(grad_accumulation_step)
  214. self._context_handle.set_grad_accumulation_step(grad_accumulation_step)
  215. def get_grad_accumulation_step(self):
  216. """Get grad accumulation step."""
  217. self.check_context_handle()
  218. return self._context_handle.get_grad_accumulation_step()
  219. def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
  220. """
  221. Set strategy checkpoint save path.
  222. Args:
  223. strategy_ckpt_save_file (bool): Path to save parallel strategy checkpoint.
  224. """
  225. self.check_context_handle()
  226. import os
  227. dir_path = os.path.dirname(strategy_ckpt_save_file)
  228. if dir_path and not os.path.exists(dir_path):
  229. os.makedirs(dir_path)
  230. self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)
  231. def get_strategy_ckpt_save_file(self):
  232. """Get strategy checkpoint save path."""
  233. self.check_context_handle()
  234. return self._context_handle.get_strategy_ckpt_save_file()
  235. def set_group_ckpt_save_file(self, group_ckpt_save_file):
  236. """Set group checkpoint save path."""
  237. self.check_context_handle()
  238. import os
  239. dir_path = os.path.dirname(group_ckpt_save_file)
  240. if dir_path and not os.path.exists(dir_path):
  241. os.makedirs(dir_path)
  242. self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)
  243. def get_parameter_broadcast_is_set(self):
  244. """Get parameter broadcast is set or not."""
  245. self.check_context_handle()
  246. return self._context_handle.get_parameter_broadcast_is_set()
  247. def set_all_reduce_fusion_split_indices(self, indices, group=""):
  248. """
  249. Set allreduce fusion strategy by parameters indices.
  250. Args:
  251. indices (list): Indices list.
  252. group (str): The communication group of hccl/nccl.
  253. Raises:
  254. TypeError: If type of indices item is not int.
  255. TypeError: If group is not a python str.
  256. """
  257. self.check_context_handle()
  258. if not indices:
  259. raise ValueError('indices can not be empty')
  260. if isinstance(indices, (list)):
  261. for index in indices:
  262. if not isinstance(index, int):
  263. raise TypeError('indices has invalid value')
  264. else:
  265. raise TypeError('indices must be a python list')
  266. if len(set(indices)) != len(indices):
  267. raise ValueError('indices has duplicate elements')
  268. if sorted(indices) != indices:
  269. raise ValueError('elements in indices must be sorted in ascending order')
  270. if isinstance(group, (str)):
  271. group_len = len(group)
  272. if group_len > _MAX_GROUP_NAME_LEN:
  273. raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
  274. else:
  275. raise TypeError('Group must be a python str')
  276. if group == "":
  277. if context.get_context("device_target") == "Ascend":
  278. group = _DEFAULT_HCCL_FUSION_GROUP_NAME
  279. else:
  280. group = _DEFAULT_NCCL_FUSION_GROUP_NAME
  281. self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
  282. if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
  283. _set_fusion_strategy_by_idx(indices)
  284. def get_all_reduce_fusion_split_indices(self, group=""):
  285. """
  286. Get allreduce fusion split indices.
  287. Args:
  288. group (str): The communication group of hccl/nccl.
  289. Returns:
  290. Return split sizes list according to the group.
  291. Raises:
  292. TypeError: If group is not a python str.
  293. """
  294. self.check_context_handle()
  295. if isinstance(group, (str)):
  296. group_len = len(group)
  297. if group_len > _MAX_GROUP_NAME_LEN:
  298. raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
  299. else:
  300. raise TypeError('Group must be a python str')
  301. if group == "":
  302. if context.get_context("device_target") == "Ascend":
  303. group = _DEFAULT_HCCL_FUSION_GROUP_NAME
  304. else:
  305. group = _DEFAULT_NCCL_FUSION_GROUP_NAME
  306. return self._context_handle.get_all_reduce_fusion_split_indices(group)
  307. def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
  308. """
  309. Set allreduce fusion strategy by parameters data sizes.
  310. Args:
  311. sizes (list): Sizes list.
  312. group (str): The communication group of hccl/nccl.
  313. Raises:
  314. TypeError: If type of sizes item is not int.
  315. TypeError: If group is not a python str.
  316. """
  317. self.check_context_handle()
  318. if isinstance(sizes, (list)):
  319. for size in sizes:
  320. if not isinstance(size, int):
  321. raise TypeError('sizes has invalid value')
  322. else:
  323. raise TypeError('sizes must be a python list')
  324. if isinstance(group, (str)):
  325. group_len = len(group)
  326. if group_len > _MAX_GROUP_NAME_LEN:
  327. raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
  328. else:
  329. raise TypeError('Group must be a python str')
  330. if group == "":
  331. if context.get_context("device_target") == "Ascend":
  332. group = _DEFAULT_HCCL_FUSION_GROUP_NAME
  333. else:
  334. group = _DEFAULT_NCCL_FUSION_GROUP_NAME
  335. self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group)
  336. if context.get_context("device_target") == "Ascend":
  337. _set_fusion_strategy_by_size(sizes)
  338. def get_all_reduce_fusion_split_sizes(self, group=""):
  339. """
  340. Get allreduce fusion split sizes.
  341. Args:
  342. group (str): The communication group of hccl/nccl.
  343. Returns:
  344. Return split sizes list according to the group.
  345. Raises:
  346. TypeError: If group is not a python str.
  347. """
  348. self.check_context_handle()
  349. if isinstance(group, (str)):
  350. group_len = len(group)
  351. if group_len > _MAX_GROUP_NAME_LEN:
  352. raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
  353. else:
  354. raise TypeError('Group must be a python str')
  355. if group == "":
  356. if context.get_context("device_target") == "Ascend":
  357. group = _DEFAULT_HCCL_FUSION_GROUP_NAME
  358. else:
  359. group = _DEFAULT_NCCL_FUSION_GROUP_NAME
  360. return self._context_handle.get_all_reduce_fusion_split_sizes(group)
  361. def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
  362. """
  363. Set enable/disable all reduce fusion.
  364. Args:
  365. enable_all_reduce_fusion (bool): Enable/disable all reduce fusion.
  366. """
  367. self.check_context_handle()
  368. if not isinstance(enable_all_reduce_fusion, bool):
  369. raise TypeError('enable_all_reduce_fusion is invalid type')
  370. self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)
  371. def get_enable_all_reduce_fusion(self):
  372. """Get all reduce fusion flag."""
  373. self.check_context_handle()
  374. return self._context_handle.get_enable_all_reduce_fusion()
  375. def get_device_num_is_set(self):
  376. """Get device number is set or not."""
  377. self.check_context_handle()
  378. return self._context_handle.get_device_num_is_set()
  379. def get_global_rank_is_set(self):
  380. """Get global rank is set or not."""
  381. self.check_context_handle()
  382. return self._context_handle.get_global_rank_is_set()
  383. def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
  384. """
  385. Set enable/disable parallel optimizer.
  386. Args:
  387. set_enable_parallel_optimizer (bool): Enable/disable parallel optimizer.
  388. """
  389. self.check_context_handle()
  390. if not isinstance(enable_parallel_optimizer, bool):
  391. raise TypeError('enable_parallel_optimizer is invalid type')
  392. self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
  393. def get_enable_parallel_optimizer(self):
  394. """Get parallel optimizer flag."""
  395. self.check_context_handle()
  396. return self._context_handle.get_enable_parallel_optimizer()
  397. def set_communi_parallel_mode(self, communi_parallel_mode):
  398. """
  399. Set communication parallel mode.
  400. Args:
  401. communi_parallel_mode (str): The communication parallel mode.
  402. Raises:
  403. ValueError: If parallel mode is not supported.
  404. """
  405. self.check_context_handle()
  406. ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode)
  407. if ret is False:
  408. raise ValueError("Communication parallel mode does not support {}".format(communi_parallel_mode))
  409. def get_communi_parallel_mode(self):
  410. """Get communication parallel mode."""
  411. self.check_context_handle()
  412. return self._context_handle.get_communi_parallel_mode()
  413. def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size):
  414. """
  415. Set optimizer_weight_shard_size.
  416. Args:
  417. optimizer_weight_shard_size (int): Opt shard group size when not globally use parallel
  418. optimizer across devices.
  419. """
  420. self.check_context_handle()
  421. if not isinstance(optimizer_weight_shard_size, int):
  422. raise TypeError('optimizer_weight_shard_size is invalid type')
  423. if optimizer_weight_shard_size <= 1:
  424. logger.warning("The setting 'optimizer_weight_shard_size' is invalid. "
  425. "Please use the integer larger than 1.")
  426. return
  427. self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size)
  428. def get_optimizer_weight_shard_size(self):
  429. """Get optimizer_weight_shard_size."""
  430. self.check_context_handle()
  431. return self._context_handle.get_optimizer_weight_shard_size()
  432. def set_optimizer_weight_shard_integrated_save(self, optimizer_weight_shard_integrated_save):
  433. """
  434. Set optimizer_weight_shard_integrated_save.
  435. Args:
  436. optimizer_weight_shard_integrated_save (bool): Whether to integrated save weight shard when
  437. enable parallel optimizer.
  438. """
  439. self.check_context_handle()
  440. if not isinstance(optimizer_weight_shard_integrated_save, bool):
  441. raise TypeError('optimizer_weight_shard_integrated_save is invalid type')
  442. self._context_handle.set_optimizer_weight_shard_integrated_save(optimizer_weight_shard_integrated_save)
  443. def get_optimizer_weight_shard_integrated_save(self):
  444. """Get optimizer_weight_shard_size."""
  445. self.check_context_handle()
  446. return self._context_handle.get_optimizer_weight_shard_integrated_save()
  447. def reset(self):
  448. """Reset all settings."""
  449. self.check_context_handle()
  450. self._context_handle.reset()
  451. _auto_parallel_context = None
  452. def auto_parallel_context():
  453. """
  454. Get the global _auto_parallel_context, if it is not created, create a new one.
  455. Returns:
  456. _AutoParallelContext, the global auto parallel context.
  457. """
  458. global _auto_parallel_context
  459. if _auto_parallel_context is None:
  460. _auto_parallel_context = _AutoParallelContext()
  461. return _auto_parallel_context
  462. _set_auto_parallel_context_func_map = {
  463. "device_num": auto_parallel_context().set_device_num,
  464. "global_rank": auto_parallel_context().set_global_rank,
  465. "gradients_mean": auto_parallel_context().set_gradients_mean,
  466. "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
  467. "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
  468. "pipeline_stages": auto_parallel_context().set_pipeline_stages,
  469. "parallel_mode": auto_parallel_context().set_parallel_mode,
  470. "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
  471. "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
  472. "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
  473. "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
  474. "group_ckpt_save_file": auto_parallel_context().set_group_ckpt_save_file,
  475. "full_batch": auto_parallel_context().set_full_batch,
  476. "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
  477. "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
  478. "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
  479. "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
  480. "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
  481. "optimizer_weight_shard_integrated_save": auto_parallel_context().set_optimizer_weight_shard_integrated_save}
  482. _get_auto_parallel_context_func_map = {
  483. "device_num": auto_parallel_context().get_device_num,
  484. "global_rank": auto_parallel_context().get_global_rank,
  485. "gradients_mean": auto_parallel_context().get_gradients_mean,
  486. "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
  487. "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
  488. "pipeline_stages": auto_parallel_context().get_pipeline_stages,
  489. "parallel_mode": auto_parallel_context().get_parallel_mode,
  490. "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
  491. "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
  492. "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
  493. "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
  494. "full_batch": auto_parallel_context().get_full_batch,
  495. "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
  496. "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
  497. "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
  498. "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
  499. "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
  500. "optimizer_weight_shard_integrated_save": auto_parallel_context().get_optimizer_weight_shard_integrated_save}
  501. @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
  502. loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
  503. parameter_broadcast=bool, strategy_ckpt_load_file=str,
  504. strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
  505. grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
  506. communi_parallel_mode=str, optimizer_weight_shard_size=int,
  507. optimizer_weight_shard_integrated_save=bool)
  508. def _set_auto_parallel_context(**kwargs):
  509. """
  510. Set auto parallel context.
  511. Note:
  512. Attribute name is required for setting attributes.
  513. Args:
  514. device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
  515. global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
  516. gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
  517. loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
  518. calculations. Default: True.
  519. gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.
  520. Default: True.
  521. parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
  522. "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
  523. - stand_alone: Only one processor working.
  524. - data_parallel: Distributing the data across different processors.
  525. - hybrid_parallel: Achieving data parallelism and model parallelism manually.
  526. - semi_auto_parallel: Achieving data parallelism and model parallelism by
  527. setting parallel strategies.
  528. - auto_parallel: Achieving parallelism automatically.
  529. auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
  530. and "dynamic_programming". Default: "dynamic_programming".
  531. - recursive_programming: Recursive programming search mode.
  532. - dynamic_programming: Dynamic programming search mode.
  533. parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
  534. "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
  535. broadcast. Default: False.
  536. strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
  537. strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
  538. group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: ''
  539. full_batch (bool): Whether to load the whole batch on each device. Default: False.
  540. enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
  541. all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.
  542. pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
  543. the devices are distributed alone the pipeline. The total devices will be divided into
  544. 'pipeline_stags' stages. This currently could only be used when
  545. parallel mode semi_auto_parallel is enabled. Default: 0
  546. communi_parallel_mode (str): There are tree kinds of communication parallel modes, "all_group_parallel",
  547. "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".
  548. - all_group_parallel: All communication groups are in parallel.
  549. - same_server_group_parallel: Only the communication groups within the same server are parallel.
  550. - no_group_parallel: All communication groups are not parallel.
  551. optimizer_weight_shard_size (int): Set optimizer shard group size when not fully use parallel optimizer.
  552. It should be larger than one and less than or equal with the data parallel size.
  553. Default: -1, which means fully use parallel optimizer in data parallel dimension.
  554. optimizer_weight_shard_integrated_save (bool): Whether to integrated save weight shard when enable parallel
  555. optimizer. Default: False.
  556. Raises:
  557. ValueError: If input key is not attribute in auto parallel context.
  558. """
  559. for key, value in kwargs.items():
  560. if key not in _set_auto_parallel_context_func_map:
  561. raise ValueError("Set context keyword %s is not recognized!" % key)
  562. set_func = _set_auto_parallel_context_func_map[key]
  563. set_func(value)
  564. def _get_auto_parallel_context(attr_key):
  565. """
  566. Get auto parallel context attribute value according to the key.
  567. Args:
  568. attr_key (str): The key of the attribute.
  569. Returns:
  570. Return attribute value according to the key.
  571. Raises:
  572. ValueError: If input key is not attribute in auto parallel context.
  573. """
  574. if attr_key not in _get_auto_parallel_context_func_map:
  575. raise ValueError("Get context keyword %s is not recognized!" % attr_key)
  576. get_func = _get_auto_parallel_context_func_map[attr_key]
  577. return get_func()
  578. def _reset_auto_parallel_context():
  579. """
  580. Reset auto parallel context attributes to the default values:
  581. - device_num: 1.
  582. - global_rank: 0.
  583. - gradients_mean: False.
  584. - gradient_fp32_sync: True.
  585. - parallel_mode: "stand_alone".
  586. - parameter_broadcast: False.
  587. - strategy_ckpt_load_file: ""
  588. - strategy_ckpt_save_file: ""
  589. - enable_parallel_optimizer: False
  590. - auto_parallel_search_mode: dynamic_programming
  591. - pipeline_stages: 0
  592. """
  593. auto_parallel_context().reset()